Use faster manual cast for fp8 in unet.

comfyanonymous
2023-12-11 18:24:44 -05:00
parent ab93abd4b2
commit ba07cb748e
5 changed files with 48 additions and 12 deletions


@@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
         return torch.float16
     return torch.float32
 
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
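This hunk only adds the dtype-selection logic; the casting itself happens in ComfyUI's patched ops at inference time. As a minimal sketch of the idea (ManualCastLinear and its class attribute are hypothetical names, not ComfyUI's real API, and fp8 storage needs a PyTorch build with float8_e4m3fn support): the weights stay stored in fp8 and are cast up to the chosen compute dtype only for the duration of each forward pass.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualCastLinear(nn.Linear):
    # Compute dtype picked by unet_manual_cast(); None means no manual cast.
    manual_cast_dtype = None

    def forward(self, x):
        if self.manual_cast_dtype is None:
            return super().forward(x)
        # Weights stay stored in fp8; cast them and the input up to the
        # compute dtype just for this matmul.
        weight = self.weight.to(self.manual_cast_dtype)
        bias = self.bias.to(self.manual_cast_dtype) if self.bias is not None else None
        return F.linear(x.to(self.manual_cast_dtype), weight, bias)

# Illustrative wiring: fp8 storage with fp32 compute, which is what
# unet_manual_cast returns when fp16 is not worthwhile on the device.
layer = ManualCastLinear(8, 8).to(torch.float8_e4m3fn)
ManualCastLinear.manual_cast_dtype = torch.float32
out = layer(torch.randn(1, 8))

This is faster than autocast for fp8 weights because the cast is a single explicit dtype conversion per op rather than a dispatch-time decision for every kernel.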
@@ -538,7 +552,7 @@ def get_autocast_device(dev):
 def supports_dtype(device, dtype): #TODO
     if dtype == torch.float32:
         return True
-    if torch.device("cpu") == device:
+    if is_device_cpu(device):
         return False
     if dtype == torch.float16:
         return True
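is_device_cpu is an existing helper in model_management; the direct equality it replaces is brittle because torch.device comparison includes the device index, so torch.device("cpu") == torch.device("cpu", 0) is False. A sketch of the check it presumably performs:

import torch

def is_device_cpu(device):
    # Match on the device type so both "cpu" and "cpu:0" qualify;
    # direct equality with torch.device("cpu") would reject the latter.
    return hasattr(device, "type") and device.type == "cpu"

assert is_device_cpu(torch.device("cpu"))
assert is_device_cpu(torch.device("cpu", 0))
assert not is_device_cpu(torch.device("meta"))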