Compare commits
11 Commits
worksplit-
...
v3-definit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50603859ab | ||
|
|
0d185b721f | ||
|
|
8642757971 | ||
|
|
de86d8e32b | ||
|
|
8b331c5ca2 | ||
|
|
937d2d5325 | ||
|
|
0400497d5e | ||
|
|
5f0e04e2d7 | ||
|
|
96c2e3856d | ||
|
|
880f756dc1 | ||
|
|
4480ed488e |
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
|||||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
cm_group = parser.add_mutually_exclusive_group()
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||||
|
|||||||
@@ -15,14 +15,13 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import copy
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
@@ -37,7 +36,7 @@ import comfy.cldm.mmdit
|
|||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.hooks import HookGroup
|
from comfy.hooks import HookGroup
|
||||||
|
|
||||||
@@ -64,18 +63,6 @@ class StrengthType(Enum):
|
|||||||
CONSTANT = 1
|
CONSTANT = 1
|
||||||
LINEAR_UP = 2
|
LINEAR_UP = 2
|
||||||
|
|
||||||
class ControlIsolation:
|
|
||||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
|
||||||
def __init__(self, control: ControlBase):
|
|
||||||
self.control = control
|
|
||||||
self.orig_previous_controlnet = control.previous_controlnet
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.control.previous_controlnet = None
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
|
||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
@@ -89,7 +76,7 @@ class ControlBase:
|
|||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
self.previous_controlnet: Union[ControlBase, None] = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
@@ -97,7 +84,6 @@ class ControlBase:
|
|||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.extra_hooks: HookGroup = None
|
self.extra_hooks: HookGroup = None
|
||||||
self.preprocess_image = lambda a: a
|
self.preprocess_image = lambda a: a
|
||||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@@ -124,38 +110,17 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
for device_cnet in self.multigpu_clones.values():
|
|
||||||
with ControlIsolation(device_cnet):
|
|
||||||
device_cnet.cleanup()
|
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
for device_cnet in self.multigpu_clones.values():
|
|
||||||
out += device_cnet.get_models_only_self()
|
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_models_only_self(self):
|
|
||||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
|
||||||
with ControlIsolation(self):
|
|
||||||
return self.get_models()
|
|
||||||
|
|
||||||
def get_instance_for_device(self, device):
|
|
||||||
'Returns instance of this Control object intended for selected device.'
|
|
||||||
return self.multigpu_clones.get(device, self)
|
|
||||||
|
|
||||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
|
||||||
'''
|
|
||||||
Create deep clone of Control object where model(s) is set to other devices.
|
|
||||||
|
|
||||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
|
||||||
'''
|
|
||||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
def get_extra_hooks(self):
|
||||||
out = []
|
out = []
|
||||||
if self.extra_hooks is not None:
|
if self.extra_hooks is not None:
|
||||||
@@ -164,7 +129,7 @@ class ControlBase:
|
|||||||
out += self.previous_controlnet.get_extra_hooks()
|
out += self.previous_controlnet.get_extra_hooks()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def copy_to(self, c: ControlBase):
|
def copy_to(self, c):
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
@@ -315,14 +280,6 @@ class ControlNet(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
|
||||||
c = self.copy()
|
|
||||||
c.control_model = copy.deepcopy(c.control_model)
|
|
||||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
|
||||||
if autoregister:
|
|
||||||
self.multigpu_clones[load_device] = c
|
|
||||||
return c
|
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = super().get_models()
|
out = super().get_models()
|
||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
@@ -848,14 +805,6 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
|
||||||
c = self.copy()
|
|
||||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
|
||||||
c.device = load_device
|
|
||||||
if autoregister:
|
|
||||||
self.multigpu_clones[load_device] = c
|
|
||||||
return c
|
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
@@ -27,10 +26,6 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@@ -176,25 +171,6 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
def get_all_torch_devices(exclude_current=False):
|
|
||||||
global cpu_state
|
|
||||||
devices = []
|
|
||||||
if cpu_state == CPUState.GPU:
|
|
||||||
if is_nvidia():
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
devices.append(torch.device(i))
|
|
||||||
elif is_intel_xpu():
|
|
||||||
for i in range(torch.xpu.device_count()):
|
|
||||||
devices.append(torch.device(i))
|
|
||||||
elif is_ascend_npu():
|
|
||||||
for i in range(torch.npu.device_count()):
|
|
||||||
devices.append(torch.device(i))
|
|
||||||
else:
|
|
||||||
devices.append(get_torch_device())
|
|
||||||
if exclude_current:
|
|
||||||
devices.remove(get_torch_device())
|
|
||||||
return devices
|
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
@@ -411,13 +387,9 @@ try:
|
|||||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
try:
|
|
||||||
for device in get_all_torch_devices(exclude_current=True):
|
|
||||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
current_loaded_models: list[LoadedModel] = []
|
|
||||||
|
current_loaded_models = []
|
||||||
|
|
||||||
def module_size(module):
|
def module_size(module):
|
||||||
module_mem = 0
|
module_mem = 0
|
||||||
@@ -428,7 +400,7 @@ def module_size(module):
|
|||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model: ModelPatcher):
|
def __init__(self, model):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
@@ -436,7 +408,7 @@ class LoadedModel:
|
|||||||
self.model_finalizer = None
|
self.model_finalizer = None
|
||||||
self._patcher_finalizer = None
|
self._patcher_finalizer = None
|
||||||
|
|
||||||
def _set_model(self, model: ModelPatcher):
|
def _set_model(self, model):
|
||||||
self._model = weakref.ref(model)
|
self._model = weakref.ref(model)
|
||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
@@ -1328,34 +1300,8 @@ def soft_empty_cache(force=False):
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
for device in get_all_torch_devices():
|
free_memory(1e30, get_torch_device())
|
||||||
free_memory(1e30, device)
|
|
||||||
|
|
||||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
|
||||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
|
||||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
|
||||||
additional_models = []
|
|
||||||
if unload_additional_models:
|
|
||||||
additional_models = model.get_nested_additional_models()
|
|
||||||
keep_loaded = []
|
|
||||||
for loaded_model in initial_keep_loaded:
|
|
||||||
if loaded_model.model is not None:
|
|
||||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
|
||||||
continue
|
|
||||||
# check additional models if they are a match
|
|
||||||
skip = False
|
|
||||||
for add_model in additional_models:
|
|
||||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
|
||||||
skip = True
|
|
||||||
break
|
|
||||||
if skip:
|
|
||||||
continue
|
|
||||||
keep_loaded.append(loaded_model)
|
|
||||||
if not all_devices:
|
|
||||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
|
||||||
else:
|
|
||||||
for device in get_all_torch_devices():
|
|
||||||
free_memory(1e30, device, keep_loaded)
|
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|||||||
@@ -84,15 +84,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
def create_model_options_clone(orig_model_options: dict):
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
def create_hook_patches_clone(orig_hook_patches):
|
||||||
new_hook_patches = {}
|
new_hook_patches = {}
|
||||||
for hook_ref in orig_hook_patches:
|
for hook_ref in orig_hook_patches:
|
||||||
new_hook_patches[hook_ref] = {}
|
new_hook_patches[hook_ref] = {}
|
||||||
for k in orig_hook_patches[hook_ref]:
|
for k in orig_hook_patches[hook_ref]:
|
||||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||||
if copy_tuples:
|
|
||||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
|
||||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
|
||||||
return new_hook_patches
|
return new_hook_patches
|
||||||
|
|
||||||
def wipe_lowvram_weight(m):
|
def wipe_lowvram_weight(m):
|
||||||
@@ -243,9 +240,6 @@ class ModelPatcher:
|
|||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
self.is_multigpu_base_clone = False
|
|
||||||
self.clone_base_uuid = uuid.uuid4()
|
|
||||||
|
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@@ -324,90 +318,16 @@ class ModelPatcher:
|
|||||||
n.is_clip = self.is_clip
|
n.is_clip = self.is_clip
|
||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
|
||||||
n.clone_base_uuid = self.clone_base_uuid
|
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
|
||||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
|
||||||
comfy.model_management.unload_model_and_clones(self)
|
|
||||||
n = self.clone()
|
|
||||||
# set load device, if present
|
|
||||||
if new_load_device is not None:
|
|
||||||
n.load_device = new_load_device
|
|
||||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
|
||||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
|
||||||
n.backup = copy.deepcopy(n.backup)
|
|
||||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
|
||||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
|
||||||
n.model = copy.deepcopy(n.model)
|
|
||||||
# multigpu clone should not have multigpu additional_models entry
|
|
||||||
n.remove_additional_models("multigpu")
|
|
||||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
|
||||||
if models_cache is None:
|
|
||||||
models_cache = {}
|
|
||||||
for key, model_list in n.additional_models.items():
|
|
||||||
for i in range(len(model_list)):
|
|
||||||
add_model = n.additional_models[key][i]
|
|
||||||
if add_model.clone_base_uuid not in models_cache:
|
|
||||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
|
||||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
|
||||||
callback(self, n)
|
|
||||||
return n
|
|
||||||
|
|
||||||
def match_multigpu_clones(self):
|
|
||||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
|
||||||
if len(multigpu_models) > 0:
|
|
||||||
new_multigpu_models = []
|
|
||||||
for mm in multigpu_models:
|
|
||||||
# clone main model, but bring over relevant props from existing multigpu clone
|
|
||||||
n = self.clone()
|
|
||||||
n.load_device = mm.load_device
|
|
||||||
n.backup = mm.backup
|
|
||||||
n.object_patches_backup = mm.object_patches_backup
|
|
||||||
n.hook_backup = mm.hook_backup
|
|
||||||
n.model = mm.model
|
|
||||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
|
||||||
n.remove_additional_models("multigpu")
|
|
||||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
|
||||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
|
||||||
# figure out which additional models are not present in multigpu clone
|
|
||||||
models_cache = {}
|
|
||||||
for mm_add_model in mm.get_additional_models():
|
|
||||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
|
||||||
remove_models_uuids = set(list(models_cache.keys()))
|
|
||||||
for key, model_list in orig_additional_models.items():
|
|
||||||
for orig_add_model in model_list:
|
|
||||||
if orig_add_model.clone_base_uuid not in models_cache:
|
|
||||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
|
||||||
existing_list = n.get_additional_models_with_key(key)
|
|
||||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
|
||||||
n.set_additional_models(key, existing_list)
|
|
||||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
|
||||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
|
||||||
# remove duplicate additional models
|
|
||||||
for key, model_list in n.additional_models.items():
|
|
||||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
|
||||||
n.set_additional_models(key, new_model_list)
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
|
||||||
callback(self, n)
|
|
||||||
new_multigpu_models.append(n)
|
|
||||||
self.set_additional_models("multigpu", new_multigpu_models)
|
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
if hasattr(other, 'model') and self.model is other.model:
|
if hasattr(other, 'model') and self.model is other.model:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||||
if allow_multigpu:
|
|
||||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
if not self.is_clone(clone):
|
if not self.is_clone(clone):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -1009,7 +929,7 @@ class ModelPatcher:
|
|||||||
return self.additional_models.get(key, [])
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
def get_additional_models(self):
|
def get_additional_models(self):
|
||||||
all_models: list[ModelPatcher] = []
|
all_models = []
|
||||||
for models in self.additional_models.values():
|
for models in self.additional_models.values():
|
||||||
all_models.extend(models)
|
all_models.extend(models)
|
||||||
return all_models
|
return all_models
|
||||||
@@ -1063,13 +983,9 @@ class ModelPatcher:
|
|||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
def prepare_state(self, timestep):
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
callback(self, timestep, model_options, ignore_multigpu)
|
callback(self, timestep)
|
||||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
|
||||||
for p in model_options["multigpu_clones"].values():
|
|
||||||
p: ModelPatcher
|
|
||||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
|
||||||
|
|
||||||
def restore_hook_patches(self):
|
def restore_hook_patches(self):
|
||||||
if self.hook_patches_backup is not None:
|
if self.hook_patches_backup is not None:
|
||||||
@@ -1082,18 +998,12 @@ class ModelPatcher:
|
|||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
multigpu_kf_changed_cache = None
|
|
||||||
transformer_options = model_options.get("transformer_options", {})
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
for hook in hook_group.hooks:
|
for hook in hook_group.hooks:
|
||||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
# this will cause the weights to be recalculated when sampling
|
# this will cause the weights to be recalculated when sampling
|
||||||
if changed:
|
if changed:
|
||||||
# cache changed for multigpu usage
|
|
||||||
if "multigpu_clones" in model_options:
|
|
||||||
if multigpu_kf_changed_cache is None:
|
|
||||||
multigpu_kf_changed_cache = []
|
|
||||||
multigpu_kf_changed_cache.append(hook)
|
|
||||||
# reset current_hooks if contains hook that changed
|
# reset current_hooks if contains hook that changed
|
||||||
if self.current_hooks is not None:
|
if self.current_hooks is not None:
|
||||||
for current_hook in self.current_hooks.hooks:
|
for current_hook in self.current_hooks.hooks:
|
||||||
@@ -1105,28 +1015,6 @@ class ModelPatcher:
|
|||||||
self.cached_hook_patches.pop(cached_group)
|
self.cached_hook_patches.pop(cached_group)
|
||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
if "multigpu_clones" in model_options:
|
|
||||||
for p in model_options["multigpu_clones"].values():
|
|
||||||
p: ModelPatcher
|
|
||||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
|
||||||
|
|
||||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
|
||||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
|
||||||
if kf_changed_cache is None:
|
|
||||||
return
|
|
||||||
reset_current_hooks = False
|
|
||||||
# reset current_hooks if contains hook that changed
|
|
||||||
for hook in kf_changed_cache:
|
|
||||||
if self.current_hooks is not None:
|
|
||||||
for current_hook in self.current_hooks.hooks:
|
|
||||||
if current_hook == hook:
|
|
||||||
reset_current_hooks = True
|
|
||||||
break
|
|
||||||
for cached_group in list(self.cached_hook_patches.keys()):
|
|
||||||
if cached_group.contains(hook):
|
|
||||||
self.cached_hook_patches.pop(cached_group)
|
|
||||||
if reset_current_hooks:
|
|
||||||
self.patch_hooks(None)
|
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||||
registered: comfy.hooks.HookGroup = None):
|
registered: comfy.hooks.HookGroup = None):
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import torch
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.patcher_extension
|
|
||||||
import comfy.model_management
|
|
||||||
|
|
||||||
|
|
||||||
class GPUOptions:
|
|
||||||
def __init__(self, device_index: int, relative_speed: float):
|
|
||||||
self.device_index = device_index
|
|
||||||
self.relative_speed = relative_speed
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return GPUOptions(self.device_index, self.relative_speed)
|
|
||||||
|
|
||||||
def create_dict(self):
|
|
||||||
return {
|
|
||||||
"relative_speed": self.relative_speed
|
|
||||||
}
|
|
||||||
|
|
||||||
class GPUOptionsGroup:
|
|
||||||
def __init__(self):
|
|
||||||
self.options: dict[int, GPUOptions] = {}
|
|
||||||
|
|
||||||
def add(self, info: GPUOptions):
|
|
||||||
self.options[info.device_index] = info
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c = GPUOptionsGroup()
|
|
||||||
for opt in self.options.values():
|
|
||||||
c.add(opt)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def register(self, model: ModelPatcher):
|
|
||||||
opts_dict = {}
|
|
||||||
# get devices that are valid for this model
|
|
||||||
devices: list[torch.device] = [model.load_device]
|
|
||||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
|
||||||
extra_model: ModelPatcher
|
|
||||||
devices.append(extra_model.load_device)
|
|
||||||
# create dictionary with actual device mapped to its GPUOptions
|
|
||||||
device_opts_list: list[GPUOptions] = []
|
|
||||||
for device in devices:
|
|
||||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
|
||||||
opts_dict[device] = device_opts.create_dict()
|
|
||||||
device_opts_list.append(device_opts)
|
|
||||||
# make relative_speed relative to 1.0
|
|
||||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
|
||||||
for value in opts_dict.values():
|
|
||||||
value['relative_speed'] /= min_speed
|
|
||||||
model.model_options['multigpu_options'] = opts_dict
|
|
||||||
|
|
||||||
|
|
||||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
|
||||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
|
||||||
model = model.clone()
|
|
||||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
|
||||||
skip_devices = set()
|
|
||||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
|
||||||
if len(multigpu_models) > 0:
|
|
||||||
for mm in multigpu_models:
|
|
||||||
skip_devices.add(mm.load_device)
|
|
||||||
skip_devices = list(skip_devices)
|
|
||||||
|
|
||||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
|
||||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
|
||||||
extra_devices = limit_extra_devices.copy()
|
|
||||||
# exclude skipped devices
|
|
||||||
for skip in skip_devices:
|
|
||||||
if skip in extra_devices:
|
|
||||||
extra_devices.remove(skip)
|
|
||||||
# create new deepclones
|
|
||||||
if len(extra_devices) > 0:
|
|
||||||
for device in extra_devices:
|
|
||||||
device_patcher = None
|
|
||||||
if reuse_loaded:
|
|
||||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
|
||||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
|
||||||
for lm in loaded_models:
|
|
||||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
|
||||||
device_patcher = lm.clone()
|
|
||||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
|
||||||
break
|
|
||||||
if device_patcher is None:
|
|
||||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
|
||||||
device_patcher.is_multigpu_base_clone = True
|
|
||||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
|
||||||
multigpu_models.append(device_patcher)
|
|
||||||
model.set_additional_models("multigpu", multigpu_models)
|
|
||||||
model.match_multigpu_clones()
|
|
||||||
if gpu_options is None:
|
|
||||||
gpu_options = GPUOptionsGroup()
|
|
||||||
gpu_options.register(model)
|
|
||||||
else:
|
|
||||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
|
||||||
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
|
||||||
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
|
||||||
# new_multigpu_models = []
|
|
||||||
# for m in multigpu_models:
|
|
||||||
# if m.load_device in limit_extra_devices:
|
|
||||||
# new_multigpu_models.append(m)
|
|
||||||
# model.set_additional_models("multigpu", new_multigpu_models)
|
|
||||||
# persist skip_devices for use in sampling code
|
|
||||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
|
||||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
|
||||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
|
||||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
|
||||||
opts_dict = model_options['multigpu_options']
|
|
||||||
devices = list(model_options['multigpu_clones'].keys())
|
|
||||||
speed_per_device = []
|
|
||||||
work_per_device = []
|
|
||||||
# get sum of each device's relative_speed
|
|
||||||
total_speed = 0.0
|
|
||||||
for opts in opts_dict.values():
|
|
||||||
total_speed += opts['relative_speed']
|
|
||||||
# get relative work for each device;
|
|
||||||
# obtained by w = (W*r)/R
|
|
||||||
for device in devices:
|
|
||||||
relative_speed = opts_dict[device]['relative_speed']
|
|
||||||
relative_work = (total_work*relative_speed) / total_speed
|
|
||||||
speed_per_device.append(relative_speed)
|
|
||||||
work_per_device.append(relative_work)
|
|
||||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
|
||||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
|
||||||
work_per_device = round_preserved(work_per_device)
|
|
||||||
dict_work_per_device = {}
|
|
||||||
for device, relative_work in zip(devices, work_per_device):
|
|
||||||
dict_work_per_device[device] = relative_work
|
|
||||||
if not return_idle_time:
|
|
||||||
return LoadBalance(dict_work_per_device, None)
|
|
||||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
|
||||||
# time here is relative and does not correspond to real-world units
|
|
||||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
|
||||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
|
||||||
idle_time = abs(min(completion_time) - max(completion_time))
|
|
||||||
# if need to compare work idle time, need to normalize to a common total work
|
|
||||||
if work_normalized:
|
|
||||||
idle_time *= (work_normalized/total_work)
|
|
||||||
|
|
||||||
return LoadBalance(dict_work_per_device, idle_time)
|
|
||||||
|
|
||||||
def round_preserved(values: list[float]):
|
|
||||||
'Round all values in a list, preserving the combined sum of values.'
|
|
||||||
# get floor of values; casting to int does it too
|
|
||||||
floored = [int(x) for x in values]
|
|
||||||
total_floored = sum(floored)
|
|
||||||
# get remainder to distribute
|
|
||||||
remainder = round(sum(values)) - total_floored
|
|
||||||
# pair values with fractional portions
|
|
||||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
|
||||||
# sort by fractional part in descending order
|
|
||||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
# distribute the remainder
|
|
||||||
for i in range(remainder):
|
|
||||||
index = fractional[i][0]
|
|
||||||
floored[index] += 1
|
|
||||||
return floored
|
|
||||||
@@ -3,8 +3,6 @@ from typing import Callable
|
|||||||
|
|
||||||
class CallbacksMP:
|
class CallbacksMP:
|
||||||
ON_CLONE = "on_clone"
|
ON_CLONE = "on_clone"
|
||||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
|
||||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
|
||||||
ON_LOAD = "on_load_after"
|
ON_LOAD = "on_load_after"
|
||||||
ON_DETACH = "on_detach_after"
|
ON_DETACH = "on_detach_after"
|
||||||
ON_CLEANUP = "on_cleanup"
|
ON_CLEANUP = "on_cleanup"
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import torch
|
|
||||||
import uuid
|
import uuid
|
||||||
import math
|
import math
|
||||||
import collections
|
import collections
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.model_patcher
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -108,47 +106,6 @@ def cleanup_additional_models(models):
|
|||||||
if hasattr(m, 'cleanup'):
|
if hasattr(m, 'cleanup'):
|
||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
|
||||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
|
||||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
|
||||||
if len(multigpu_models) == 0:
|
|
||||||
return
|
|
||||||
extra_devices = [x.load_device for x in multigpu_models]
|
|
||||||
# handle controlnets
|
|
||||||
controlnets: set[ControlBase] = set()
|
|
||||||
for k in conds:
|
|
||||||
for kk in conds[k]:
|
|
||||||
if 'control' in kk:
|
|
||||||
controlnets.add(kk['control'])
|
|
||||||
if len(controlnets) > 0:
|
|
||||||
# first, unload all controlnet clones
|
|
||||||
for cnet in list(controlnets):
|
|
||||||
cnet_models = cnet.get_models()
|
|
||||||
for cm in cnet_models:
|
|
||||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
|
||||||
|
|
||||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
|
||||||
for cnet in controlnets:
|
|
||||||
curr_cnet = cnet
|
|
||||||
while curr_cnet is not None:
|
|
||||||
for device in extra_devices:
|
|
||||||
if device not in curr_cnet.multigpu_clones:
|
|
||||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
|
||||||
curr_cnet = curr_cnet.previous_controlnet
|
|
||||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
|
||||||
for cnet in controlnets:
|
|
||||||
curr_cnet = cnet
|
|
||||||
while curr_cnet is not None:
|
|
||||||
prev_cnet = curr_cnet.previous_controlnet
|
|
||||||
for device in extra_devices:
|
|
||||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
|
||||||
prev_device_cnet = None
|
|
||||||
if prev_cnet is not None:
|
|
||||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
|
||||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
|
||||||
curr_cnet = prev_cnet
|
|
||||||
# potentially handle gligen - since not widely used, ignored for now
|
|
||||||
|
|
||||||
def estimate_memory(model, noise_shape, conds):
|
def estimate_memory(model, noise_shape, conds):
|
||||||
cond_shapes = collections.defaultdict(list)
|
cond_shapes = collections.defaultdict(list)
|
||||||
cond_shapes_min = {}
|
cond_shapes_min = {}
|
||||||
@@ -173,8 +130,7 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
|||||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||||
model.match_multigpu_clones()
|
real_model: BaseModel = None
|
||||||
preprocess_multigpu_conds(conds, model, model_options)
|
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
@@ -193,7 +149,7 @@ def cleanup_models(conds, models):
|
|||||||
|
|
||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||||
'''
|
'''
|
||||||
Registers hooks from conds.
|
Registers hooks from conds.
|
||||||
'''
|
'''
|
||||||
@@ -226,18 +182,3 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
|||||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||||
copy_dict1=False)
|
copy_dict1=False)
|
||||||
return to_load_options
|
return to_load_options
|
||||||
|
|
||||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
|
||||||
'''
|
|
||||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
|
||||||
'''
|
|
||||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
|
||||||
if len(multigpu_patchers) > 0:
|
|
||||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
|
||||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
|
||||||
for x in multigpu_patchers:
|
|
||||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
|
||||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
|
||||||
multigpu_dict[x.load_device] = x
|
|
||||||
model_options["multigpu_clones"] = multigpu_dict
|
|
||||||
return multigpu_patchers
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||||
@@ -20,7 +18,6 @@ import comfy.patcher_extension
|
|||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
def add_area_dims(area, num_dims):
|
def add_area_dims(area, num_dims):
|
||||||
@@ -143,7 +140,7 @@ def can_concat_cond(c1, c2):
|
|||||||
|
|
||||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||||
|
|
||||||
def cond_cat(c_list, device=None):
|
def cond_cat(c_list):
|
||||||
temp = {}
|
temp = {}
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
for k in x:
|
for k in x:
|
||||||
@@ -155,8 +152,6 @@ def cond_cat(c_list, device=None):
|
|||||||
for k in temp:
|
for k in temp:
|
||||||
conds = temp[k]
|
conds = temp[k]
|
||||||
out[k] = conds[0].concat(conds[1:])
|
out[k] = conds[0].concat(conds[1:])
|
||||||
if device is not None and hasattr(out[k], 'to'):
|
|
||||||
out[k] = out[k].to(device)
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -210,9 +205,7 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
|
|||||||
)
|
)
|
||||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
if 'multigpu_clones' in model_options:
|
|
||||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
# separate conds by matching hooks
|
# separate conds by matching hooks
|
||||||
@@ -244,7 +237,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
if has_default_conds:
|
if has_default_conds:
|
||||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
model.current_patcher.prepare_state(timestep, model_options)
|
model.current_patcher.prepare_state(timestep)
|
||||||
|
|
||||||
# run every hooked_to_run separately
|
# run every hooked_to_run separately
|
||||||
for hooks, to_run in hooked_to_run.items():
|
for hooks, to_run in hooked_to_run.items():
|
||||||
@@ -352,190 +345,6 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
|
|
||||||
return out_conds
|
return out_conds
|
||||||
|
|
||||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
|
||||||
out_conds = []
|
|
||||||
out_counts = []
|
|
||||||
# separate conds by matching hooks
|
|
||||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
|
||||||
default_conds = []
|
|
||||||
has_default_conds = False
|
|
||||||
|
|
||||||
output_device = x_in.device
|
|
||||||
|
|
||||||
for i in range(len(conds)):
|
|
||||||
out_conds.append(torch.zeros_like(x_in))
|
|
||||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
|
||||||
|
|
||||||
cond = conds[i]
|
|
||||||
default_c = []
|
|
||||||
if cond is not None:
|
|
||||||
for x in cond:
|
|
||||||
if 'default' in x:
|
|
||||||
default_c.append(x)
|
|
||||||
has_default_conds = True
|
|
||||||
continue
|
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
if p.hooks is not None:
|
|
||||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
|
||||||
hooked_to_run.setdefault(p.hooks, list())
|
|
||||||
hooked_to_run[p.hooks] += [(p, i)]
|
|
||||||
default_conds.append(default_c)
|
|
||||||
|
|
||||||
if has_default_conds:
|
|
||||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
|
||||||
|
|
||||||
model.current_patcher.prepare_state(timestep, model_options)
|
|
||||||
|
|
||||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
|
||||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
|
||||||
|
|
||||||
total_conds = 0
|
|
||||||
for to_run in hooked_to_run.values():
|
|
||||||
total_conds += len(to_run)
|
|
||||||
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
|
||||||
index_device = 0
|
|
||||||
current_device = devices[index_device]
|
|
||||||
# run every hooked_to_run separately
|
|
||||||
for hooks, to_run in hooked_to_run.items():
|
|
||||||
while len(to_run) > 0:
|
|
||||||
current_device = devices[index_device % len(devices)]
|
|
||||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
|
||||||
# keep track of conds currently scheduled onto this device
|
|
||||||
batched_to_run_length = 0
|
|
||||||
for btr in batched_to_run:
|
|
||||||
batched_to_run_length += len(btr[1])
|
|
||||||
|
|
||||||
first = to_run[0]
|
|
||||||
first_shape = first[0][0].shape
|
|
||||||
to_batch_temp = []
|
|
||||||
# make sure not over conds_per_device limit when creating temp batch
|
|
||||||
for x in range(len(to_run)):
|
|
||||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
|
||||||
to_batch_temp += [x]
|
|
||||||
|
|
||||||
to_batch_temp.reverse()
|
|
||||||
to_batch = to_batch_temp[:1]
|
|
||||||
|
|
||||||
free_memory = model_management.get_free_memory(current_device)
|
|
||||||
for i in range(1, len(to_batch_temp) + 1):
|
|
||||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
|
||||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
|
||||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
|
||||||
to_batch = batch_amount
|
|
||||||
break
|
|
||||||
conds_to_batch = []
|
|
||||||
for x in to_batch:
|
|
||||||
conds_to_batch.append(to_run.pop(x))
|
|
||||||
batched_to_run_length += len(conds_to_batch)
|
|
||||||
|
|
||||||
batched_to_run.append((hooks, conds_to_batch))
|
|
||||||
if batched_to_run_length >= conds_per_device:
|
|
||||||
index_device += 1
|
|
||||||
|
|
||||||
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
|
|
||||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
|
||||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
|
||||||
# run every hooked_to_run separately
|
|
||||||
with torch.no_grad():
|
|
||||||
for hooks, to_batch in batch_tuple:
|
|
||||||
input_x = []
|
|
||||||
mult = []
|
|
||||||
c = []
|
|
||||||
cond_or_uncond = []
|
|
||||||
uuids = []
|
|
||||||
area = []
|
|
||||||
control: ControlBase = None
|
|
||||||
patches = None
|
|
||||||
for x in to_batch:
|
|
||||||
o = x
|
|
||||||
p = o[0]
|
|
||||||
input_x.append(p.input_x)
|
|
||||||
mult.append(p.mult)
|
|
||||||
c.append(p.conditioning)
|
|
||||||
area.append(p.area)
|
|
||||||
cond_or_uncond.append(o[1])
|
|
||||||
uuids.append(p.uuid)
|
|
||||||
control = p.control
|
|
||||||
patches = p.patches
|
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
|
||||||
input_x = torch.cat(input_x).to(device)
|
|
||||||
c = cond_cat(c, device=device)
|
|
||||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
|
||||||
|
|
||||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
|
||||||
if 'transformer_options' in model_options:
|
|
||||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
|
||||||
model_options['transformer_options'],
|
|
||||||
copy_dict1=False)
|
|
||||||
|
|
||||||
if patches is not None:
|
|
||||||
# TODO: replace with merge_nested_dicts function
|
|
||||||
if "patches" in transformer_options:
|
|
||||||
cur_patches = transformer_options["patches"].copy()
|
|
||||||
for p in patches:
|
|
||||||
if p in cur_patches:
|
|
||||||
cur_patches[p] = cur_patches[p] + patches[p]
|
|
||||||
else:
|
|
||||||
cur_patches[p] = patches[p]
|
|
||||||
transformer_options["patches"] = cur_patches
|
|
||||||
else:
|
|
||||||
transformer_options["patches"] = patches
|
|
||||||
|
|
||||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
|
||||||
transformer_options["uuids"] = uuids[:]
|
|
||||||
transformer_options["sigmas"] = timestep
|
|
||||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
|
||||||
transformer_options["multigpu_thread_device"] = device
|
|
||||||
|
|
||||||
cast_transformer_options(transformer_options, device=device)
|
|
||||||
c['transformer_options'] = transformer_options
|
|
||||||
|
|
||||||
if control is not None:
|
|
||||||
device_control = control.get_instance_for_device(device)
|
|
||||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
|
||||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
|
||||||
else:
|
|
||||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
|
||||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
|
||||||
|
|
||||||
|
|
||||||
results: list[thread_result] = []
|
|
||||||
threads: list[threading.Thread] = []
|
|
||||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
|
||||||
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
|
||||||
threads.append(new_thread)
|
|
||||||
new_thread.start()
|
|
||||||
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
for output, mult, area, batch_chunks, cond_or_uncond in results:
|
|
||||||
for o in range(batch_chunks):
|
|
||||||
cond_index = cond_or_uncond[o]
|
|
||||||
a = area[o]
|
|
||||||
if a is None:
|
|
||||||
out_conds[cond_index] += output[o] * mult[o]
|
|
||||||
out_counts[cond_index] += mult[o]
|
|
||||||
else:
|
|
||||||
out_c = out_conds[cond_index]
|
|
||||||
out_cts = out_counts[cond_index]
|
|
||||||
dims = len(a) // 2
|
|
||||||
for i in range(dims):
|
|
||||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
|
||||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
|
||||||
out_c += output[o] * mult[o]
|
|
||||||
out_cts += mult[o]
|
|
||||||
|
|
||||||
for i in range(len(out_conds)):
|
|
||||||
out_conds[i] /= out_counts[i]
|
|
||||||
|
|
||||||
return out_conds
|
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||||
@@ -833,8 +642,6 @@ def pre_run_control(model, conds):
|
|||||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
for device_cnet in x['control'].multigpu_clones.values():
|
|
||||||
device_cnet.pre_run(model, percent_to_timestep_function)
|
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@@ -1077,9 +884,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
to_load_options = model_options.get("to_load_options", None)
|
to_load_options = model_options.get("to_load_options", None)
|
||||||
if to_load_options is None:
|
if to_load_options is None:
|
||||||
return
|
return
|
||||||
cast_transformer_options(to_load_options, device, dtype)
|
|
||||||
|
|
||||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
|
||||||
casts = []
|
casts = []
|
||||||
if device is not None:
|
if device is not None:
|
||||||
casts.append(device)
|
casts.append(device)
|
||||||
@@ -1088,17 +893,18 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
|||||||
# if nothing to apply, do nothing
|
# if nothing to apply, do nothing
|
||||||
if len(casts) == 0:
|
if len(casts) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# try to call .to on patches
|
# try to call .to on patches
|
||||||
if "patches" in transformer_options:
|
if "patches" in to_load_options:
|
||||||
patches = transformer_options["patches"]
|
patches = to_load_options["patches"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for i in range(len(patch_list)):
|
for i in range(len(patch_list)):
|
||||||
if hasattr(patch_list[i], "to"):
|
if hasattr(patch_list[i], "to"):
|
||||||
for cast in casts:
|
for cast in casts:
|
||||||
patch_list[i] = patch_list[i].to(cast)
|
patch_list[i] = patch_list[i].to(cast)
|
||||||
if "patches_replace" in transformer_options:
|
if "patches_replace" in to_load_options:
|
||||||
patches = transformer_options["patches_replace"]
|
patches = to_load_options["patches_replace"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for k in patch_list:
|
for k in patch_list:
|
||||||
@@ -1108,8 +914,8 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
|||||||
# try to call .to on any wrappers/callbacks
|
# try to call .to on any wrappers/callbacks
|
||||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||||
for wc_name in wrappers_and_callbacks:
|
for wc_name in wrappers_and_callbacks:
|
||||||
if wc_name in transformer_options:
|
if wc_name in to_load_options:
|
||||||
wc: dict[str, list] = transformer_options[wc_name]
|
wc: dict[str, list] = to_load_options[wc_name]
|
||||||
for wc_dict in wc.values():
|
for wc_dict in wc.values():
|
||||||
for wc_list in wc_dict.values():
|
for wc_list in wc_dict.values():
|
||||||
for i in range(len(wc_list)):
|
for i in range(len(wc_list)):
|
||||||
@@ -1117,6 +923,7 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
|||||||
for cast in casts:
|
for cast in casts:
|
||||||
wc_list[i] = wc_list[i].to(cast)
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher: ModelPatcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher = model_patcher
|
self.model_patcher = model_patcher
|
||||||
@@ -1162,8 +969,6 @@ class CFGGuider:
|
|||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||||
|
|
||||||
@@ -1174,13 +979,9 @@ class CFGGuider:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
|
||||||
multigpu_patcher.pre_run()
|
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
finally:
|
finally:
|
||||||
self.model_patcher.cleanup()
|
self.model_patcher.cleanup()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
|
||||||
multigpu_patcher.cleanup()
|
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
|
|||||||
855
comfy_api/v3/io.py
Normal file
855
comfy_api/v3/io.py
Normal file
@@ -0,0 +1,855 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, Literal
|
||||||
|
from enum import Enum
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
|
||||||
|
|
||||||
|
class InputBehavior(str, Enum):
|
||||||
|
required = "required"
|
||||||
|
optional = "optional"
|
||||||
|
|
||||||
|
|
||||||
|
def is_class(obj):
|
||||||
|
'''
|
||||||
|
Returns True if is a class type.
|
||||||
|
Returns False if is a class instance.
|
||||||
|
'''
|
||||||
|
return isinstance(obj, type)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberDisplay(str, Enum):
|
||||||
|
number = "number"
|
||||||
|
slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
|
class IO_V3:
|
||||||
|
'''
|
||||||
|
Base class for V3 Inputs and Outputs.
|
||||||
|
'''
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init_subclass__(cls, io_type: IO | str, **kwargs):
|
||||||
|
cls.io_type = io_type
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
class InputV3(IO_V3, io_type=None):
|
||||||
|
'''
|
||||||
|
Base class for a V3 Input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None):
|
||||||
|
super().__init__()
|
||||||
|
self.id = id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.behavior = behavior
|
||||||
|
self.tooltip = tooltip
|
||||||
|
self.lazy = lazy
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return prune_dict({
|
||||||
|
"display_name": self.display_name,
|
||||||
|
"tooltip": self.tooltip,
|
||||||
|
"lazy": self.lazy
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_io_type_V1(self):
|
||||||
|
return self.io_type
|
||||||
|
|
||||||
|
class WidgetInputV3(InputV3, io_type=None):
|
||||||
|
'''
|
||||||
|
Base class for a V3 Input with widget.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: Any=None,
|
||||||
|
socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy)
|
||||||
|
self.default = default
|
||||||
|
self.socketless = socketless
|
||||||
|
self.widgetType = widgetType
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"default": self.default,
|
||||||
|
"socketless": self.socketless,
|
||||||
|
"widgetType": self.widgetType,
|
||||||
|
})
|
||||||
|
|
||||||
|
def CustomType(io_type: IO | str) -> type[IO_V3]:
|
||||||
|
name = f"{io_type}_IO_V3"
|
||||||
|
return type(name, (IO_V3,), {}, io_type=io_type)
|
||||||
|
|
||||||
|
def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
|
||||||
|
'''
|
||||||
|
Defines input for 'io_type'. Can be used to stand in for non-core types.
|
||||||
|
'''
|
||||||
|
input_kwargs = {
|
||||||
|
"id": id,
|
||||||
|
"display_name": display_name,
|
||||||
|
"behavior": behavior,
|
||||||
|
"tooltip": tooltip,
|
||||||
|
"lazy": lazy,
|
||||||
|
}
|
||||||
|
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
|
||||||
|
|
||||||
|
def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
|
||||||
|
'''
|
||||||
|
Defines output for 'io_type'. Can be used to stand in for non-core types.
|
||||||
|
'''
|
||||||
|
input_kwargs = {
|
||||||
|
"id": id,
|
||||||
|
"display_name": display_name,
|
||||||
|
"tooltip": tooltip,
|
||||||
|
}
|
||||||
|
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
|
||||||
|
'''
|
||||||
|
Boolean input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: bool=None, label_on: str=None, label_off: str=None,
|
||||||
|
socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
|
||||||
|
self.label_on = label_on
|
||||||
|
self.label_off = label_off
|
||||||
|
self.default: bool
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"label_on": self.label_on,
|
||||||
|
"label_off": self.label_off,
|
||||||
|
})
|
||||||
|
|
||||||
|
class IntegerInput(WidgetInputV3, io_type=IO.INT):
|
||||||
|
'''
|
||||||
|
Integer input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||||
|
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
self.step = step
|
||||||
|
self.control_after_generate = control_after_generate
|
||||||
|
self.display_mode = display_mode
|
||||||
|
self.default: int
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"min": self.min,
|
||||||
|
"max": self.max,
|
||||||
|
"step": self.step,
|
||||||
|
"control_after_generate": self.control_after_generate,
|
||||||
|
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
||||||
|
})
|
||||||
|
|
||||||
|
class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
|
||||||
|
'''
|
||||||
|
Float input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
|
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
|
||||||
|
self.default = default
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
self.step = step
|
||||||
|
self.round = round
|
||||||
|
self.display_mode = display_mode
|
||||||
|
self.default: float
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"min": self.min,
|
||||||
|
"max": self.max,
|
||||||
|
"step": self.step,
|
||||||
|
"round": self.round,
|
||||||
|
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
||||||
|
})
|
||||||
|
|
||||||
|
class StringInput(WidgetInputV3, io_type=IO.STRING):
|
||||||
|
'''
|
||||||
|
String input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
multiline=False, placeholder: str=None, default: int=None,
|
||||||
|
socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
|
||||||
|
self.multiline = multiline
|
||||||
|
self.placeholder = placeholder
|
||||||
|
self.default: str
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"multiline": self.multiline,
|
||||||
|
"placeholder": self.placeholder,
|
||||||
|
})
|
||||||
|
|
||||||
|
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
|
||||||
|
'''Combo input (dropdown).'''
|
||||||
|
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: str=None, control_after_generate: bool=None,
|
||||||
|
socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
|
||||||
|
self.multiselect = False
|
||||||
|
self.options = options
|
||||||
|
self.control_after_generate = control_after_generate
|
||||||
|
self.default: str
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"multiselect": self.multiselect,
|
||||||
|
"options": self.options,
|
||||||
|
"control_after_generate": self.control_after_generate,
|
||||||
|
})
|
||||||
|
|
||||||
|
class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
|
||||||
|
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||||
|
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||||
|
socketless: bool=None, widgetType: str=None):
|
||||||
|
super().__init__(id, options, display_name, behavior, tooltip, lazy, default, control_after_generate, socketless, widgetType)
|
||||||
|
self.multiselect = True
|
||||||
|
self.placeholder = placeholder
|
||||||
|
self.chip = chip
|
||||||
|
self.default: list[str]
|
||||||
|
|
||||||
|
def as_dict_V1(self):
|
||||||
|
return super().as_dict_V1() | prune_dict({
|
||||||
|
"multiselect": self.multiselect,
|
||||||
|
"placeholder": self.placeholder,
|
||||||
|
"chip": self.chip,
|
||||||
|
})
|
||||||
|
|
||||||
|
class ImageInput(InputV3, io_type=IO.IMAGE):
|
||||||
|
'''
|
||||||
|
Image input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
|
class MaskInput(InputV3, io_type=IO.MASK):
|
||||||
|
'''
|
||||||
|
Mask input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
|
class LatentInput(InputV3, io_type=IO.LATENT):
|
||||||
|
'''
|
||||||
|
Latent input.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
|
class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
|
||||||
|
'''
|
||||||
|
Input that permits more than one input type.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
|
||||||
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
self._io_types = io_types
|
||||||
|
|
||||||
|
@property
|
||||||
|
def io_types(self) -> list[type[InputV3]]:
|
||||||
|
'''
|
||||||
|
Returns list of InputV3 class types permitted.
|
||||||
|
'''
|
||||||
|
io_types = []
|
||||||
|
for x in self._io_types:
|
||||||
|
if not is_class(x):
|
||||||
|
io_types.append(type(x))
|
||||||
|
else:
|
||||||
|
io_types.append(x)
|
||||||
|
return io_types
|
||||||
|
|
||||||
|
def get_io_type_V1(self):
|
||||||
|
return ",".join(x.io_type for x in self.io_types)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputV3:
|
||||||
|
def __init__(self, id: str, display_name: str=None, tooltip: str=None,
|
||||||
|
is_output_list=False):
|
||||||
|
self.id = id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.tooltip = tooltip
|
||||||
|
self.is_output_list = is_output_list
|
||||||
|
|
||||||
|
def __init_subclass__(cls, io_type, **kwargs):
|
||||||
|
cls.io_type = io_type
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
class IntegerOutput(OutputV3, io_type=IO.INT):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FloatOutput(OutputV3, io_type=IO.FLOAT):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class StringOutput(OutputV3, io_type=IO.STRING):
|
||||||
|
pass
|
||||||
|
# def __init__(self, id: str, display_name: str=None, tooltip: str=None):
|
||||||
|
# super().__init__(id, display_name, tooltip)
|
||||||
|
|
||||||
|
class ImageOutput(OutputV3, io_type=IO.IMAGE):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class MaskOutput(OutputV3, io_type=IO.MASK):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class LatentOutput(OutputV3, io_type=IO.LATENT):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicInput(InputV3, io_type=None):
|
||||||
|
'''
|
||||||
|
Abstract class for dynamic input registration.
|
||||||
|
'''
|
||||||
|
def __init__(self, io_type: str, id: str, display_name: str=None):
|
||||||
|
super().__init__(io_type, id, display_name)
|
||||||
|
|
||||||
|
class DynamicOutput(OutputV3, io_type=None):
|
||||||
|
'''
|
||||||
|
Abstract class for dynamic output registration.
|
||||||
|
'''
|
||||||
|
def __init__(self, io_type: str, id: str, display_name: str=None):
|
||||||
|
super().__init__(io_type, id, display_name)
|
||||||
|
|
||||||
|
class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
|
||||||
|
'''
|
||||||
|
Dynamic Input that adds another template_input each time one is provided.
|
||||||
|
|
||||||
|
Additional inputs are forced to have 'InputBehavior.optional'.
|
||||||
|
'''
|
||||||
|
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None):
|
||||||
|
super().__init__("AutoGrowDynamicInput", id)
|
||||||
|
self.template_input = template_input
|
||||||
|
if min is not None:
|
||||||
|
assert(min >= 1)
|
||||||
|
if max is not None:
|
||||||
|
assert(max >= 1)
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
|
||||||
|
class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"):
|
||||||
|
def __init__(self, id: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
AutoGrowDynamicInput(id="dynamic", template_input=ImageInput(id="image"))
|
||||||
|
|
||||||
|
|
||||||
|
class Hidden(str, Enum):
|
||||||
|
'''
|
||||||
|
Enumerator for requesting hidden variables in nodes.
|
||||||
|
'''
|
||||||
|
|
||||||
|
unique_id = "UNIQUE_ID"
|
||||||
|
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||||
|
prompt = "PROMPT"
|
||||||
|
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
||||||
|
extra_pnginfo = "EXTRA_PNGINFO"
|
||||||
|
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
||||||
|
dynprompt = "DYNPROMPT"
|
||||||
|
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
||||||
|
auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG"
|
||||||
|
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||||
|
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||||
|
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeInfoV1:
|
||||||
|
input: dict=None
|
||||||
|
input_order: dict[str, list[str]]=None
|
||||||
|
output: list[str]=None
|
||||||
|
output_is_list: list[bool]=None
|
||||||
|
output_name: list[str]=None
|
||||||
|
output_tooltips: list[str]=None
|
||||||
|
name: str=None
|
||||||
|
display_name: str=None
|
||||||
|
description: str=None
|
||||||
|
python_module: Any=None
|
||||||
|
category: str=None
|
||||||
|
output_node: bool=None
|
||||||
|
deprecated: bool=None
|
||||||
|
experimental: bool=None
|
||||||
|
api_node: bool=None
|
||||||
|
|
||||||
|
|
||||||
|
def as_pruned_dict(dataclass_obj):
|
||||||
|
'''Return dict of dataclass object with pruned None values.'''
|
||||||
|
return prune_dict(asdict(dataclass_obj))
|
||||||
|
|
||||||
|
def prune_dict(d: dict):
|
||||||
|
return {k: v for k,v in d.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SchemaV3:
|
||||||
|
"""Definition of V3 node properties."""
|
||||||
|
|
||||||
|
node_id: str
|
||||||
|
"""ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
|
||||||
|
display_name: str = None
|
||||||
|
"""Display name of node."""
|
||||||
|
category: str = "sd"
|
||||||
|
"""The category of the node, as per the "Add Node" menu."""
|
||||||
|
inputs: list[InputV3]=None
|
||||||
|
outputs: list[OutputV3]=None
|
||||||
|
hidden: list[Hidden]=None
|
||||||
|
description: str=""
|
||||||
|
"""Node description, shown as a tooltip when hovering over the node."""
|
||||||
|
is_input_list: bool = False
|
||||||
|
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||||
|
|
||||||
|
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
||||||
|
|
||||||
|
From the docs:
|
||||||
|
|
||||||
|
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||||
|
|
||||||
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||||
|
"""
|
||||||
|
is_output_node: bool=False
|
||||||
|
"""Flags this node as an output node, causing any inputs it requires to be executed.
|
||||||
|
|
||||||
|
If a node is not connected to any output nodes, that node will not be executed. Usage::
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
From the docs:
|
||||||
|
|
||||||
|
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||||
|
|
||||||
|
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
||||||
|
"""
|
||||||
|
is_deprecated: bool=False
|
||||||
|
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
|
is_experimental: bool=False
|
||||||
|
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
|
is_api_node: bool=False
|
||||||
|
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||||
|
|
||||||
|
# class SchemaV3Class:
|
||||||
|
# def __init__(self,
|
||||||
|
# node_id: str,
|
||||||
|
# node_name: str,
|
||||||
|
# category: str,
|
||||||
|
# inputs: list[InputV3],
|
||||||
|
# outputs: list[OutputV3]=None,
|
||||||
|
# hidden: list[Hidden]=None,
|
||||||
|
# description: str="",
|
||||||
|
# is_input_list: bool = False,
|
||||||
|
# is_output_node: bool=False,
|
||||||
|
# is_deprecated: bool=False,
|
||||||
|
# is_experimental: bool=False,
|
||||||
|
# is_api_node: bool=False,
|
||||||
|
# ):
|
||||||
|
# self.node_id = node_id
|
||||||
|
# """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
|
||||||
|
# self.node_name = node_name
|
||||||
|
# """Display name of node."""
|
||||||
|
# self.category = category
|
||||||
|
# """The category of the node, as per the "Add Node" menu."""
|
||||||
|
# self.inputs = inputs
|
||||||
|
# self.outputs = outputs
|
||||||
|
# self.hidden = hidden
|
||||||
|
# self.description = description
|
||||||
|
# """Node description, shown as a tooltip when hovering over the node."""
|
||||||
|
# self.is_input_list = is_input_list
|
||||||
|
# """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||||
|
|
||||||
|
# All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
||||||
|
|
||||||
|
# From the docs:
|
||||||
|
|
||||||
|
# A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
||||||
|
|
||||||
|
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||||
|
# """
|
||||||
|
# self.is_output_node = is_output_node
|
||||||
|
# """Flags this node as an output node, causing any inputs it requires to be executed.
|
||||||
|
|
||||||
|
# If a node is not connected to any output nodes, that node will not be executed. Usage::
|
||||||
|
|
||||||
|
# OUTPUT_NODE = True
|
||||||
|
|
||||||
|
# From the docs:
|
||||||
|
|
||||||
|
# By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
||||||
|
|
||||||
|
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
||||||
|
# """
|
||||||
|
# self.is_deprecated = is_deprecated
|
||||||
|
# """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
|
# self.is_experimental = is_experimental
|
||||||
|
# """Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
|
# self.is_api_node = is_api_node
|
||||||
|
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(object):
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
return self.f(owner)
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyNodeV3(ABC):
|
||||||
|
"""Common base class for all V3 nodes."""
|
||||||
|
|
||||||
|
RELATIVE_PYTHON_MODULE = None
|
||||||
|
#############################################
|
||||||
|
# V1 Backwards Compatibility code
|
||||||
|
#--------------------------------------------
|
||||||
|
_DESCRIPTION = None
|
||||||
|
@classproperty
|
||||||
|
def DESCRIPTION(cls):
|
||||||
|
if cls._DESCRIPTION is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._DESCRIPTION
|
||||||
|
|
||||||
|
_CATEGORY = None
|
||||||
|
@classproperty
|
||||||
|
def CATEGORY(cls):
|
||||||
|
if cls._CATEGORY is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._CATEGORY
|
||||||
|
|
||||||
|
_EXPERIMENTAL = None
|
||||||
|
@classproperty
|
||||||
|
def EXPERIMENTAL(cls):
|
||||||
|
if cls._EXPERIMENTAL is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._EXPERIMENTAL
|
||||||
|
|
||||||
|
_DEPRECATED = None
|
||||||
|
@classproperty
|
||||||
|
def DEPRECATED(cls):
|
||||||
|
if cls._DEPRECATED is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._DEPRECATED
|
||||||
|
|
||||||
|
_API_NODE = None
|
||||||
|
@classproperty
|
||||||
|
def API_NODE(cls):
|
||||||
|
if cls._API_NODE is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._API_NODE
|
||||||
|
|
||||||
|
_OUTPUT_NODE = None
|
||||||
|
@classproperty
|
||||||
|
def OUTPUT_NODE(cls):
|
||||||
|
if cls._OUTPUT_NODE is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._OUTPUT_NODE
|
||||||
|
|
||||||
|
_INPUT_IS_LIST = None
|
||||||
|
@classproperty
|
||||||
|
def INPUT_IS_LIST(cls):
|
||||||
|
if cls._INPUT_IS_LIST is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._INPUT_IS_LIST
|
||||||
|
_OUTPUT_IS_LIST = None
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def OUTPUT_IS_LIST(cls):
|
||||||
|
if cls._OUTPUT_IS_LIST is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._OUTPUT_IS_LIST
|
||||||
|
|
||||||
|
_RETURN_TYPES = None
|
||||||
|
@classproperty
|
||||||
|
def RETURN_TYPES(cls):
|
||||||
|
if cls._RETURN_TYPES is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._RETURN_TYPES
|
||||||
|
|
||||||
|
_RETURN_NAMES = None
|
||||||
|
@classproperty
|
||||||
|
def RETURN_NAMES(cls):
|
||||||
|
if cls._RETURN_NAMES is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._RETURN_NAMES
|
||||||
|
|
||||||
|
_OUTPUT_TOOLTIPS = None
|
||||||
|
@classproperty
|
||||||
|
def OUTPUT_TOOLTIPS(cls):
|
||||||
|
if cls._OUTPUT_TOOLTIPS is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._OUTPUT_TOOLTIPS
|
||||||
|
|
||||||
|
FUNCTION = "execute"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||||
|
schema = cls.DEFINE_SCHEMA()
|
||||||
|
# for V1, make inputs be a dict with potential keys {required, optional, hidden}
|
||||||
|
input = {
|
||||||
|
"required": {}
|
||||||
|
}
|
||||||
|
if schema.inputs:
|
||||||
|
for i in schema.inputs:
|
||||||
|
input.setdefault(i.behavior.value, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
|
||||||
|
if schema.hidden:
|
||||||
|
for hidden in schema.hidden:
|
||||||
|
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||||
|
return input
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_SCHEMA(cls) -> SchemaV3:
|
||||||
|
schema = cls.DEFINE_SCHEMA()
|
||||||
|
if cls._DESCRIPTION is None:
|
||||||
|
cls._DESCRIPTION = schema.description
|
||||||
|
if cls._CATEGORY is None:
|
||||||
|
cls._CATEGORY = schema.category
|
||||||
|
if cls._EXPERIMENTAL is None:
|
||||||
|
cls._EXPERIMENTAL = schema.is_experimental
|
||||||
|
if cls._DEPRECATED is None:
|
||||||
|
cls._DEPRECATED = schema.is_deprecated
|
||||||
|
if cls._API_NODE is None:
|
||||||
|
cls._API_NODE = schema.is_api_node
|
||||||
|
if cls._OUTPUT_NODE is None:
|
||||||
|
cls._OUTPUT_NODE = schema.is_output_node
|
||||||
|
if cls._INPUT_IS_LIST is None:
|
||||||
|
cls._INPUT_IS_LIST = schema.is_input_list
|
||||||
|
|
||||||
|
if cls._RETURN_TYPES is None:
|
||||||
|
output = []
|
||||||
|
output_name = []
|
||||||
|
output_is_list = []
|
||||||
|
output_tooltips = []
|
||||||
|
if schema.outputs:
|
||||||
|
for o in schema.outputs:
|
||||||
|
output.append(o.io_type)
|
||||||
|
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||||
|
output_is_list.append(o.is_output_list)
|
||||||
|
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||||
|
|
||||||
|
cls._RETURN_TYPES = output
|
||||||
|
cls._RETURN_NAMES = output_name
|
||||||
|
cls._OUTPUT_IS_LIST = output_is_list
|
||||||
|
cls._OUTPUT_TOOLTIPS = output_tooltips
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
|
||||||
|
schema = cls.GET_SCHEMA()
|
||||||
|
# get V1 inputs
|
||||||
|
input = cls.INPUT_TYPES()
|
||||||
|
|
||||||
|
# create separate lists from output fields
|
||||||
|
output = []
|
||||||
|
output_is_list = []
|
||||||
|
output_name = []
|
||||||
|
output_tooltips = []
|
||||||
|
if schema.outputs:
|
||||||
|
for o in schema.outputs:
|
||||||
|
output.append(o.io_type)
|
||||||
|
output_is_list.append(o.is_output_list)
|
||||||
|
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||||
|
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||||
|
|
||||||
|
info = NodeInfoV1(
|
||||||
|
input=input,
|
||||||
|
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||||
|
output=output,
|
||||||
|
output_is_list=output_is_list,
|
||||||
|
output_name=output_name,
|
||||||
|
output_tooltips=output_tooltips,
|
||||||
|
name=schema.node_id,
|
||||||
|
display_name=schema.display_name,
|
||||||
|
category=schema.category,
|
||||||
|
description=schema.description,
|
||||||
|
output_node=schema.is_output_node,
|
||||||
|
deprecated=schema.is_deprecated,
|
||||||
|
experimental=schema.is_experimental,
|
||||||
|
api_node=schema.is_api_node,
|
||||||
|
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||||
|
)
|
||||||
|
return asdict(info)
|
||||||
|
#--------------------------------------------
|
||||||
|
#############################################
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||||
|
schema = cls.GET_SCHEMA()
|
||||||
|
# TODO: finish
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def DEFINE_SCHEMA(cls) -> SchemaV3:
|
||||||
|
"""
|
||||||
|
Override this function with one that returns a SchemaV3 instance.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
DEFINE_SCHEMA = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if self.DEFINE_SCHEMA is None:
|
||||||
|
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def execute(self, **kwargs) -> NodeOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# class ReturnedInputs:
|
||||||
|
# def __init__(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# class ReturnedOutputs:
|
||||||
|
# def __init__(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeOutput:
|
||||||
|
'''
|
||||||
|
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
|
||||||
|
'''
|
||||||
|
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.ui = ui
|
||||||
|
self.expand = expand
|
||||||
|
self.block_execution = block_execution
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self):
|
||||||
|
return self.args if len(self.args) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
class SavedResult:
|
||||||
|
def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
|
||||||
|
self.filename = filename
|
||||||
|
self.subfolder = subfolder
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {
|
||||||
|
"filename": self.filename,
|
||||||
|
"subfolder": self.subfolder,
|
||||||
|
"type": self.type
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIOutput(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
... # TODO: finish
|
||||||
|
|
||||||
|
class UIImages(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
|
||||||
|
self.values = values
|
||||||
|
self.animated = animated
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"images": values,
|
||||||
|
"animated": (self.animated,)
|
||||||
|
}
|
||||||
|
|
||||||
|
class UILatents(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"latents": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIAudio(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"audio": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UI3D(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"3d": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIText(UIOutput):
|
||||||
|
def __init__(self, value: str, **kwargs):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
|
class TestNode(ComfyNodeV3):
|
||||||
|
SCHEMA = SchemaV3(
|
||||||
|
node_id="TestNode_v3",
|
||||||
|
display_name="Test Node (V3)",
|
||||||
|
category="v3_test",
|
||||||
|
inputs=[IntegerInput("my_int"),
|
||||||
|
#AutoGrowDynamicInput("growing", ImageInput),
|
||||||
|
MaskInput("thing"),
|
||||||
|
],
|
||||||
|
outputs=[ImageOutput("image_output")],
|
||||||
|
hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
# @classmethod
|
||||||
|
# def GET_SCHEMA(cls):
|
||||||
|
# return cls.SCHEMA
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def DEFINE_SCHEMA(cls):
|
||||||
|
return cls.SCHEMA
|
||||||
|
|
||||||
|
def execute(**kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("hello there")
|
||||||
|
inputs: list[InputV3] = [
|
||||||
|
IntegerInput("my_int"),
|
||||||
|
CustomInput("xyz", "XYZ"),
|
||||||
|
CustomInput("model1", "MODEL_M"),
|
||||||
|
ImageInput("my_image"),
|
||||||
|
FloatInput("my_float"),
|
||||||
|
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs: list[OutputV3] = [
|
||||||
|
ImageOutput("image"),
|
||||||
|
CustomOutput("xyz", "XYZ")
|
||||||
|
]
|
||||||
|
|
||||||
|
for c in inputs:
|
||||||
|
if isinstance(c, MultitypedInput):
|
||||||
|
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
|
||||||
|
print(c.get_io_type_V1())
|
||||||
|
else:
|
||||||
|
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
|
||||||
|
|
||||||
|
for c in outputs:
|
||||||
|
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
|
||||||
|
|
||||||
|
zz = TestNode()
|
||||||
|
print(zz.GET_NODE_INFO_V1())
|
||||||
|
|
||||||
|
# aa = NodeInfoV1()
|
||||||
|
# print(asdict(aa))
|
||||||
|
# print(as_pruned_dict(aa))
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from inspect import cleandoc
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
import comfy.multigpu
|
|
||||||
|
|
||||||
|
|
||||||
class MultiGPUWorkUnitsNode:
|
|
||||||
"""
|
|
||||||
Prepares model to have sampling accelerated via splitting work units.
|
|
||||||
|
|
||||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
|
||||||
|
|
||||||
Other than those exceptions, this node can be placed in any order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
NodeId = "MultiGPU_WorkUnits"
|
|
||||||
NodeName = "MultiGPU Work Units"
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("MODEL",),
|
|
||||||
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"gpu_options": ("GPU_OPTIONS",)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("MODEL",)
|
|
||||||
FUNCTION = "init_multigpu"
|
|
||||||
CATEGORY = "advanced/multigpu"
|
|
||||||
DESCRIPTION = cleandoc(__doc__)
|
|
||||||
|
|
||||||
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
|
||||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
|
|
||||||
return (model,)
|
|
||||||
|
|
||||||
class MultiGPUOptionsNode:
|
|
||||||
"""
|
|
||||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
|
||||||
"""
|
|
||||||
|
|
||||||
NodeId = "MultiGPU_Options"
|
|
||||||
NodeName = "MultiGPU Options"
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
|
|
||||||
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"gpu_options": ("GPU_OPTIONS",)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("GPU_OPTIONS",)
|
|
||||||
FUNCTION = "create_gpu_options"
|
|
||||||
CATEGORY = "advanced/multigpu"
|
|
||||||
DESCRIPTION = cleandoc(__doc__)
|
|
||||||
|
|
||||||
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
|
||||||
if not gpu_options:
|
|
||||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
|
||||||
gpu_options.clone()
|
|
||||||
|
|
||||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
|
||||||
gpu_options.add(opt)
|
|
||||||
|
|
||||||
return (gpu_options,)
|
|
||||||
|
|
||||||
|
|
||||||
node_list = [
|
|
||||||
MultiGPUWorkUnitsNode,
|
|
||||||
MultiGPUOptionsNode
|
|
||||||
]
|
|
||||||
NODE_CLASS_MAPPINGS = {}
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
|
||||||
|
|
||||||
for node in node_list:
|
|
||||||
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
|
||||||
67
comfy_extras/nodes_v3_test.py
Normal file
67
comfy_extras/nodes_v3_test.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.v3.io import (
|
||||||
|
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
|
||||||
|
IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class V3TestNode(ComfyNodeV3):
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def DEFINE_SCHEMA(cls):
|
||||||
|
return SchemaV3(
|
||||||
|
node_id="V3TestNode1",
|
||||||
|
display_name="V3 Test Node (1djekjd)",
|
||||||
|
description="This is a funky V3 node test.",
|
||||||
|
category="v3 nodes",
|
||||||
|
inputs=[
|
||||||
|
IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
|
||||||
|
MaskInput("mask", behavior=InputBehavior.optional),
|
||||||
|
ImageInput("image", display_name="new_image"),
|
||||||
|
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ),
|
||||||
|
# ComboDynamicInput("mask", behavior=InputBehavior.optional),
|
||||||
|
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider,
|
||||||
|
# dependent_inputs=[ComboDynamicInput("mask", behavior=InputBehavior.optional)],
|
||||||
|
# dependent_values=[lambda my_value: IO.STRING if my_value < 5 else IO.NUMBER],
|
||||||
|
# ),
|
||||||
|
# ["option1", "option2". "option3"]
|
||||||
|
# ComboDynamicInput["sdfgjhl", [ComboDynamicOptions("option1", [IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ImageInput(), MaskInput(), String()]),
|
||||||
|
# CombyDynamicOptons("option2", [])
|
||||||
|
# ]]
|
||||||
|
],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
|
||||||
|
a = NodeOutput(1)
|
||||||
|
aa = NodeOutput(1, "hellothere")
|
||||||
|
ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
|
||||||
|
b = NodeOutput()
|
||||||
|
c = NodeOutput(ui={"lol": "jk"})
|
||||||
|
return NodeOutput()
|
||||||
|
return NodeOutput(1)
|
||||||
|
return NodeOutput(1, block_execution="Kill yourself")
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
NODES_LIST: list[ComfyNodeV3] = [
|
||||||
|
V3TestNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# NODE_CLASS_MAPPINGS = {}
|
||||||
|
# NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
# for node in NODES_LIST:
|
||||||
|
# schema = node.GET_SCHEMA()
|
||||||
|
# NODE_CLASS_MAPPINGS[schema.node_id] = node
|
||||||
|
# if schema.display_name:
|
||||||
|
# NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||||
17
execution.py
17
execution.py
@@ -17,6 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
|
|||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
from comfy_api.v3.io import NodeOutput
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
@@ -242,6 +243,22 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
|||||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
results.append(result)
|
results.append(result)
|
||||||
subgraph_results.append((None, result))
|
subgraph_results.append((None, result))
|
||||||
|
elif isinstance(r, NodeOutput):
|
||||||
|
if r.ui is not None:
|
||||||
|
uis.append(r.ui.as_dict())
|
||||||
|
if r.expand is not None:
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r.expand
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif r.result is not None:
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
else:
|
else:
|
||||||
if isinstance(r, ExecutionBlocker):
|
if isinstance(r, ExecutionBlocker):
|
||||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
|
|||||||
17
nodes.py
17
nodes.py
@@ -26,6 +26,7 @@ import comfy.sd
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.controlnet
|
import comfy.controlnet
|
||||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||||
|
from comfy_api.v3.io import ComfyNodeV3
|
||||||
|
|
||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
@@ -2129,6 +2130,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
|||||||
if os.path.isdir(web_dir):
|
if os.path.isdir(web_dir):
|
||||||
EXTENSION_WEB_DIRS[module_name] = web_dir
|
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||||
|
|
||||||
|
# V1 node definition
|
||||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||||
if name not in ignore:
|
if name not in ignore:
|
||||||
@@ -2137,8 +2139,19 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
|||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
return True
|
return True
|
||||||
|
# V3 node definition
|
||||||
|
elif getattr(module, "NODES_LIST", None) is not None:
|
||||||
|
for node_cls in module.NODES_LIST:
|
||||||
|
node_cls: ComfyNodeV3
|
||||||
|
schema = node_cls.GET_SCHEMA()
|
||||||
|
if schema.node_id not in ignore:
|
||||||
|
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||||
|
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||||
|
if schema.display_name is not None:
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(traceback.format_exc())
|
logging.warning(traceback.format_exc())
|
||||||
@@ -2241,7 +2254,6 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_mahiro.py",
|
"nodes_mahiro.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
"nodes_multigpu.py",
|
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
"nodes_cosmos.py",
|
"nodes_cosmos.py",
|
||||||
"nodes_video.py",
|
"nodes_video.py",
|
||||||
@@ -2259,6 +2271,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_ace.py",
|
"nodes_ace.py",
|
||||||
"nodes_string.py",
|
"nodes_string.py",
|
||||||
"nodes_camera_trajectory.py",
|
"nodes_camera_trajectory.py",
|
||||||
|
"nodes_v3_test.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import comfy.model_management
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from comfyui_version import __version__
|
from comfyui_version import __version__
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
|
from comfy_api.v3.io import ComfyNodeV3
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@@ -555,6 +556,8 @@ class PromptServer():
|
|||||||
|
|
||||||
def node_info(node_class):
|
def node_info(node_class):
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
if isinstance(obj_class, ComfyNodeV3):
|
||||||
|
return obj_class.GET_NODE_INFO_V1()
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
|
|||||||
Reference in New Issue
Block a user