Use faster manual cast for fp8 in unet.
@@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
         return torch.float16
     return torch.float32
 
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -538,7 +552,7 @@ def get_autocast_device(dev):
 def supports_dtype(device, dtype): #TODO
     if dtype == torch.float32:
         return True
-    if torch.device("cpu") == device:
+    if is_device_cpu(device):
         return False
     if dtype == torch.float16:
         return True
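For context, the dtype returned by unet_manual_cast() is intended as the compute dtype when the unet weights are stored in a low-precision format such as fp8: the weights stay in the storage dtype and are cast just before each op, rather than relying on autocast. The sketch below illustrates that pattern; it is not code from this commit, and the CastLinear class name and manual_cast_dtype attribute are assumptions made for the example.

import torch
import torch.nn as nn

class CastLinear(nn.Linear):
    # Hypothetical layer: weights stay in the storage dtype (e.g. fp8) and are
    # cast to manual_cast_dtype right before the matmul.
    manual_cast_dtype = None  # e.g. the value returned by unet_manual_cast(); None disables casting

    def forward(self, x):
        if self.manual_cast_dtype is None:
            return super().forward(x)
        weight = self.weight.to(self.manual_cast_dtype)
        bias = self.bias.to(self.manual_cast_dtype) if self.bias is not None else None
        return nn.functional.linear(x.to(self.manual_cast_dtype), weight, bias)

Casting per layer like this keeps the fp8 copy as the only resident set of weights while the actual math runs in fp16 or fp32.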