Fix issue with regular torch version.
This commit is contained in:
@@ -390,7 +390,11 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
return torch_dev
|
return torch_dev
|
||||||
|
|
||||||
cpu_dev = torch.device("cpu")
|
cpu_dev = torch.device("cpu")
|
||||||
model_size = dtype.itemsize * parameters
|
dtype_size = 4
|
||||||
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||||
|
dtype_size = 2
|
||||||
|
|
||||||
|
model_size = dtype_size * parameters
|
||||||
|
|
||||||
mem_dev = get_free_memory(torch_dev)
|
mem_dev = get_free_memory(torch_dev)
|
||||||
mem_cpu = get_free_memory(cpu_dev)
|
mem_cpu = get_free_memory(cpu_dev)
|
||||||
|
|||||||
Reference in New Issue
Block a user