Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time

This commit is contained in:
Jedrzej Kosinski
2025-01-05 21:07:02 -06:00
parent db2d7ad9ba
commit 8270ff312f
4 changed files with 119 additions and 35 deletions

View File

@@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# Try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# Try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@@ -861,7 +910,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -870,6 +919,7 @@ class CFGGuider:
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run()
@@ -906,6 +956,7 @@ class CFGGuider:
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()