Load flux t5 in fp8 if weights are in fp8.
@@ -661,6 +661,17 @@ def supports_cast(device, dtype): #TODO
         return True
     return False

+def pick_weight_dtype(dtype, fallback_dtype, device=None):
+    if dtype is None:
+        dtype = fallback_dtype
+    elif dtype_size(dtype) > dtype_size(fallback_dtype):
+        dtype = fallback_dtype
+
+    if not supports_cast(device, dtype):
+        dtype = fallback_dtype
+
+    return dtype
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False #pytorch bug? mps doesn't support non blocking
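The new pick_weight_dtype helper clamps a requested weight dtype: it keeps the stored dtype only when it is no larger than the fallback and the device can cast from it, otherwise it returns the fallback. A minimal usage sketch, assuming the helper lives in comfy.model_management next to supports_cast, that torch.float8_e4m3fn is the stored fp8 dtype, and that a caller like this exists (the loader code below is illustrative, not part of this commit):

import torch
import comfy.model_management as mm

stored_dtype = torch.float8_e4m3fn  # dtype the checkpoint weights were saved in
fallback = torch.float16            # dtype the text encoder would use otherwise

# fp8 is smaller than fp16, so the stored dtype is kept unless
# supports_cast(device, fp8) rejects it for this device.
dtype = mm.pick_weight_dtype(stored_dtype, fallback, mm.get_torch_device())

This is what lets the flux T5 text encoder stay in fp8 when the checkpoint was saved in fp8, instead of always upcasting to the fallback dtype.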