Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type

This commit is contained in:
Jedrzej Kosinski
2025-01-04 02:04:07 -06:00
parent 776aa734e1
commit 111fd0cadf
4 changed files with 53 additions and 57 deletions

View File

@@ -940,16 +940,16 @@ class ModelPatcher:
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None):
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
# handle TransformerOptionsHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
if len(weight_hooks_to_register) > 0:
@@ -958,7 +958,7 @@ class ModelPatcher:
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target_dict)
callback(self, hooks, target_dict)
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():