Try to fix a memory issue with LoRA.

comfyanonymous
2023-07-22 21:26:45 -04:00
parent 67be7eb81d
commit 22f29d66ca
2 changed files with 12 additions and 5 deletions

@@ -281,19 +281,23 @@ def load_model_gpu(model):
         vram_set_state = VRAMState.LOW_VRAM
 
     real_model = model.model
+    patch_model_to = None
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(torch_dev)
+        patch_model_to = torch_dev
 
     try:
-        real_model = model.patch_model()
+        real_model = model.patch_model(device_to=patch_model_to)
     except Exception as e:
         model.unpatch_model()
         unload_model()
         raise e
 
+    if patch_model_to is not None:
+        real_model.to(torch_dev)
+
     if vram_set_state == VRAMState.NO_VRAM:
         device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
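What the change accomplishes: previously real_model.to(torch_dev) ran before patch_model(), so the LoRA patch arithmetic happened against weights already resident on the GPU, which plausibly meant holding the full unpatched model plus the patch temporaries in VRAM at once. Now the target device is handed to patch_model() as device_to, and the blanket .to(torch_dev) move runs only afterwards, presumably as a catch-all for anything the patch step left on the CPU. Below is a minimal standalone sketch of that patch-then-move idea, assuming precomputed additive LoRA-style deltas; the deltas dict and the per-tensor move are illustrative, not ComfyUI's actual ModelPatcher internals.

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def patch_model(model: nn.Module, deltas: dict, device_to=None) -> nn.Module:
        # Hypothetical stand-in for ComfyUI's ModelPatcher.patch_model:
        # `deltas` maps parameter names to precomputed additive weight deltas.
        for name, param in model.named_parameters():
            delta = deltas.get(name)
            if delta is not None:
                # Patch where the weight currently lives (the CPU, until the
                # caller moves the model), so the patch arithmetic never needs
                # a second full copy of the weights in VRAM.
                param += delta.to(device=param.device, dtype=param.dtype)
            if device_to is not None:
                # Stream each finished tensor to the target device one at a
                # time instead of moving the whole model up front.
                param.data = param.data.to(device_to)
        return model

    # Usage (illustrative): patch on the CPU, land the result on the GPU.
    # unet = patch_model(unet, lora_deltas, device_to=torch.device("cuda"))

Under those assumptions the peak is roughly one copy of the weights plus a single tensor in flight, rather than two device-resident copies during patching.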