Load flux t5 in fp8 if weights are in fp8.

This commit is contained in:
comfyanonymous
2024-08-01 11:05:56 -04:00
parent 8d34211a7a
commit 5f98de7697
4 changed files with 29 additions and 12 deletions

View File

@@ -661,6 +661,17 @@ def supports_cast(device, dtype): #TODO
return True
return False
def pick_weight_dtype(dtype, fallback_dtype, device=None):
    """Select the dtype to load weights in, preferring *dtype* but
    falling back to *fallback_dtype* when *dtype* is unset, wider
    (larger element size) than the fallback, or not castable on *device*.

    NOTE(review): relies on module-level helpers `dtype_size` and
    `supports_cast` defined elsewhere in this file.
    """
    chosen = dtype
    # No explicit request, or the request is wider than the fallback:
    # use the fallback instead. (short-circuits so dtype_size never
    # sees None)
    if chosen is None or dtype_size(chosen) > dtype_size(fallback_dtype):
        chosen = fallback_dtype
    # Even the chosen dtype is only usable if the device can cast it.
    if not supports_cast(device, chosen):
        chosen = fallback_dtype
    return chosen
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking