Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type
This commit is contained in:
@@ -940,16 +940,16 @@ class ModelPatcher:
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None):
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None):
|
||||
self.restore_hook_patches()
|
||||
registered_hooks: list[comfy.hooks.Hook] = []
|
||||
# handle WrapperHooks, if model_options provided
|
||||
# handle TransformerOptionsHooks, if model_options provided
|
||||
if model_options is not None:
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}):
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
|
||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||
# handle WeightHooks
|
||||
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
|
||||
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
|
||||
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
|
||||
if hook.hook_ref not in self.hook_patches:
|
||||
weight_hooks_to_register.append(hook)
|
||||
if len(weight_hooks_to_register) > 0:
|
||||
@@ -958,7 +958,7 @@ class ModelPatcher:
|
||||
for hook in weight_hooks_to_register:
|
||||
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
|
||||
callback(self, hooks_dict, target_dict)
|
||||
callback(self, hooks, target_dict)
|
||||
|
||||
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
|
||||
with self.use_ejected():
|
||||
|
||||
Reference in New Issue
Block a user