Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time
This commit is contained in:
@@ -65,7 +65,7 @@ class _HookRef:
|
||||
pass
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
'''Example for how should_register function should look like.'''
|
||||
return True
|
||||
|
||||
@@ -114,10 +114,10 @@ class Hook:
|
||||
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
|
||||
return c
|
||||
|
||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
||||
|
||||
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
@@ -154,7 +154,7 @@ class WeightHook(Hook):
|
||||
def strength_clip(self):
|
||||
return self._strength_clip * self.strength
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
weights = None
|
||||
@@ -178,7 +178,7 @@ class WeightHook(Hook):
|
||||
else:
|
||||
weights = self.weights
|
||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
||||
registered.append(self)
|
||||
registered.add(self)
|
||||
return True
|
||||
# TODO: add logs about any keys that were not applied
|
||||
|
||||
@@ -212,11 +212,12 @@ class AddModelsHook(Hook):
|
||||
|
||||
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
||||
'''
|
||||
def __init__(self, key: str=None, models: list[ModelPatcher]=None):
|
||||
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
||||
super().__init__(hook_type=EnumHookType.AddModels)
|
||||
self.key = key
|
||||
self.models = models
|
||||
self.key = key
|
||||
self.append_when_same = True
|
||||
'''Curently does nothing.'''
|
||||
|
||||
def clone(self, subtype: Callable=None):
|
||||
if subtype is None:
|
||||
@@ -227,9 +228,10 @@ class AddModelsHook(Hook):
|
||||
c.append_when_same = self.append_when_same
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
class TransformerOptionsHook(Hook):
|
||||
@@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook):
|
||||
c.transformers_dict = self.transformers_dict
|
||||
return c
|
||||
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
|
||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
if not self.should_register(model, model_options, target_dict, registered):
|
||||
return False
|
||||
add_model_options = {"transformer_options": self.transformers_dict}
|
||||
# TODO: call .to on patches/anything else in transformer_options that is expected to do something
|
||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.append(self)
|
||||
add_model_options = {"transformer_options": self.transformers_dict,
|
||||
"to_load_options": self.transformers_dict}
|
||||
else:
|
||||
add_model_options = {"to_load_options": self.transformers_dict}
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
||||
registered.add(self)
|
||||
return True
|
||||
|
||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
||||
@@ -295,6 +300,9 @@ class HookGroup:
|
||||
self.hooks: list[Hook] = []
|
||||
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.hooks)
|
||||
|
||||
def add(self, hook: Hook):
|
||||
if hook not in self.hooks:
|
||||
self.hooks.append(hook)
|
||||
|
||||
Reference in New Issue
Block a user