Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type

This commit is contained in:
Jedrzej Kosinski
2025-01-04 02:04:07 -06:00
parent 776aa734e1
commit 111fd0cadf
4 changed files with 53 additions and 57 deletions

View File

@@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
if 'control' in c:
cnets.append(c['control'])
@@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
return hooks_dict
return full_hooks
def convert_cond(cond):
out = []
@@ -73,7 +70,7 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
hooks = comfy.hooks.HookGroup()
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
@@ -90,7 +87,10 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
hook_models = []
for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels):
x: comfy.hooks.AddModelsHook
hook_models.extend(x.models)
models = control_models + gligen + add_models + hook_models
return models, inference_memory
@@ -124,7 +124,7 @@ def cleanup_models(conds, models):
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
hooks = comfy.hooks.HookGroup()
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options