Lower memory usage for loras in lowvram mode at the cost of perf.

Author: comfyanonymous
Date: 2024-03-13 19:04:41 -04:00
parent eda8704386
commit db8b59ecff
3 changed files with 101 additions and 48 deletions

comfy/model_management.py

@@ -272,7 +272,6 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device

     def model_memory(self):
@@ -285,52 +284,27 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0):
-        patch_model_to = None
-        if lowvram_model_memory == 0:
-            patch_model_to = self.device
+        patch_model_to = self.device

         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e

-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

         return self.real_model

     def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
         self.model.unpatch_model(self.model.offload_device)
         self.model.model_patches_to(self.model.offload_device)
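
The per-module placement loop deleted from model_load above does not simply go away: it presumably moves into the new ModelPatcher.patch_model_lowvram called in the try block (one of the other changed files in this commit, not shown here), where the patcher can also choose to apply lora patches on the fly when a weight is cast instead of merging them into the weights up front. That deferral is where the memory saving and the performance cost named in the commit message would come from. Below is a minimal sketch of such a placement pass, written as a standalone helper that mirrors the removed loop; module_size, comfy_cast_weights and the byte-budget semantics come straight from the diff, while lowvram_load_sketch itself and the make_lazy_patch hook are hypothetical names used only for illustration.

import logging

import torch


def lowvram_load_sketch(real_model: torch.nn.Module, device, lowvram_model_memory: int,
                        module_size, make_lazy_patch=None):
    # Hypothetical stand-in for the pass that moved out of model_load: walk the
    # modules and move them to `device` until the byte budget is spent, enabling
    # on-the-fly weight casting for the modules that support it (this mirrors the
    # loop removed from model_load above).
    logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
    mem_counter = 0
    for m in real_model.modules():
        if hasattr(m, "comfy_cast_weights"):
            m.prev_comfy_cast_weights = m.comfy_cast_weights
            m.comfy_cast_weights = True
            module_mem = module_size(m)
            if mem_counter + module_mem < lowvram_model_memory:
                m.to(device)
                mem_counter += module_mem
            elif make_lazy_patch is not None:
                # Assumption: for modules left on the offload device, the patcher
                # could attach a callable that applies the lora patch each time the
                # weight is cast for a forward pass, trading compute for memory.
                m.weight_function = make_lazy_patch(m)
        elif hasattr(m, "weight"):  # only modules with comfy_cast_weights can be set to lowvram mode
            m.to(device)
            mem_counter += module_size(m)
            logging.warning("lowvram: loaded module regularly {}".format(m))
    return mem_counter

Reading the commit this way would also explain the model_unload changes: once the lowvram bookkeeping (model_accelerated, prev_comfy_cast_weights) lives inside the patcher, unpatch_model can restore that state itself and LoadedModel no longer needs its own accelerated flag or restore loop.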