Refactor code so the model can use a dtype other than fp32 or fp16.

This commit is contained in:
comfyanonymous
2023-10-13 14:35:21 -04:00
parent fee3b0c070
commit 9a55dadb4c
6 changed files with 39 additions and 41 deletions

View File

@@ -448,6 +448,11 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def unet_dtype(device=None, model_params=0):
    """Select the torch dtype for loading UNet weights.

    Returns torch.float16 when should_use_fp16 judges the device/model-size
    combination safe for half precision, otherwise torch.float32.
    """
    half_ok = should_use_fp16(device=device, model_params=model_params)
    return torch.float16 if half_ok else torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()