Greatly improve lowvram sampling speed by getting rid of accelerate.
Let me know if this breaks anything.
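For context: the diff below (comfy/model_management.py) replaces accelerate's infer_auto_device_map/dispatch_model hooks with a manual placement pass. Each module that supports ComfyUI's on-the-fly weight casting (a comfy_cast_weights flag) gets that flag enabled, its weight size is measured from its state_dict, and it is moved to the GPU only while the running total stays under the lowvram budget; whatever does not fit keeps its weights on the offload device and has them cast at execution time. A minimal, self-contained sketch of that placement loop follows (the helper name load_lowvram, the leaf-module check, and the toy usage are illustrative, not code from this commit):

import torch
import torch.nn as nn

def load_lowvram(model: nn.Module, device: torch.device, budget_bytes: int) -> int:
    # Sketch of the placement strategy: move leaf modules to `device` until
    # `budget_bytes` of weights have been placed; everything else stays on the
    # offload device and would have its weights cast per forward pass.
    mem_counter = 0
    for m in model.modules():
        if len(list(m.children())) > 0:
            continue  # only consider leaf modules so parameters are not counted twice
        module_mem = sum(t.nelement() * t.element_size() for t in m.state_dict().values())
        if mem_counter + module_mem < budget_bytes:
            m.to(device)
            mem_counter += module_mem
    return mem_counter

# Illustrative usage with a toy model; "cpu" stands in for the GPU device so it runs anywhere.
toy = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
moved = load_lowvram(toy, torch.device("cpu"), budget_bytes=16 * 1024 * 1024)
print("placed", moved / (1024 * 1024), "MiB within the budget")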
@@ -218,15 +218,8 @@ if args.force_fp16:
     FORCE_FP16 = True
 
 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
 
 
 if cpu_state != CPUState.GPU:
@@ -298,8 +291,20 @@ class LoadedModel:
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = 0
+                    sd = m.state_dict()
+                    for k in sd:
+                        t = sd[k]
+                        module_mem += t.nelement() * t.element_size()
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+
             self.model_accelerated = True
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
@@ -309,7 +314,11 @@ class LoadedModel:
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
             self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
             else:
                 lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
@@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False
 
+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
         elif is_intel_xpu():
             device_supports_cast = True
 
-    non_blocking = True
-    if is_device_mps(device):
-        non_blocking = False #pytorch bug? mps doesn't support non blocking
+    non_blocking = device_supports_non_blocking(device)
 
     if device_supports_cast:
         if copy:
@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight
 
 #TODO: might be cleaner to put this somewhere else