Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time

This commit is contained in:
Jedrzej Kosinski
2025-01-05 21:07:02 -06:00
parent db2d7ad9ba
commit 8270ff312f
4 changed files with 119 additions and 35 deletions

View File

@@ -65,7 +65,7 @@ class _HookRef:
pass
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how should_register function should look like.'''
return True
@@ -114,10 +114,10 @@ class Hook:
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
@@ -154,7 +154,7 @@ class WeightHook(Hook):
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
@@ -178,7 +178,7 @@ class WeightHook(Hook):
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
@@ -212,11 +212,12 @@ class AddModelsHook(Hook):
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, key: str=None, models: list[ModelPatcher]=None):
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.key = key
self.append_when_same = True
'''Curently does nothing.'''
def clone(self, subtype: Callable=None):
if subtype is None:
@@ -227,9 +228,10 @@ class AddModelsHook(Hook):
c.append_when_same = self.append_when_same
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
registered.add(self)
return True
class TransformerOptionsHook(Hook):
@@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook):
c.transformers_dict = self.transformers_dict
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.transformers_dict}
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
if self.hook_scope == EnumHookScope.AllConditioning:
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
else:
add_model_options = {"to_load_options": self.transformers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.add(self)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
@@ -295,6 +300,9 @@ class HookGroup:
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)