Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski
2025-02-24 19:23:16 -06:00
15 changed files with 232 additions and 60 deletions

View File

@@ -245,7 +245,7 @@ def is_amd():
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.1
MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
@@ -281,9 +281,12 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if is_nvidia() and args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
except:
pass
@@ -710,6 +713,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if PRIORITIZE_FP16:
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
@@ -972,11 +979,11 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
upcast = True
if upcast:
return torch.float32
return {torch.float16: torch.float32}
else:
return None
@@ -1050,8 +1057,6 @@ def is_directml_enabled():
return False
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
if device is not None:
if is_device_cpu(device):
return False
@@ -1062,8 +1067,8 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP32:
return False
if directml_enabled:
return False
if is_directml_enabled():
return True
if (device is not None and is_device_mps(device)) or mps_mode():
return True
@@ -1150,7 +1155,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
if bf16_works and manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True