Compare commits

..

79 Commits

Author SHA1 Message Date
kosinkadink1@gmail.com
0336b0ace8 Merge branch 'master' into worksplit-multigpu 2025-06-01 02:39:26 -07:00
kosinkadink1@gmail.com
8ae25235ec Merge branch 'master' into worksplit-multigpu 2025-05-21 12:01:27 -07:00
Jedrzej Kosinski
9726eac475 Merge branch 'master' into worksplit-multigpu 2025-05-12 19:29:13 -05:00
Jedrzej Kosinski
272e8d42c1 Merge branch 'master' into worksplit-multigpu 2025-04-22 22:40:00 -05:00
Jedrzej Kosinski
6211d2be5a Merge branch 'master' into worksplit-multigpu 2025-04-19 17:36:23 -05:00
Jedrzej Kosinski
8be711715c Make unload_all_models account for all devices 2025-04-19 17:35:54 -05:00
Jedrzej Kosinski
b5cccf1325 Merge branch 'master' into worksplit-multigpu 2025-04-18 15:39:34 -05:00
Jedrzej Kosinski
2a54a904f4 Merge branch 'master' into worksplit-multigpu 2025-04-16 19:26:48 -05:00
Jedrzej Kosinski
ed6f92c975 Merge branch 'master' into worksplit-multigpu 2025-04-16 16:53:57 -05:00
Jedrzej Kosinski
adc66c0698 Merge branch 'master' into worksplit-multigpu 2025-04-16 14:23:56 -05:00
Jedrzej Kosinski
ccd5c01e5a Merge branch 'master' into worksplit-multigpu 2025-04-09 09:17:12 -05:00
Jedrzej Kosinski
2fa9affcc1 Merge branch 'master' into worksplit-multigpu 2025-04-08 22:52:17 -05:00
Jedrzej Kosinski
407a5a656f Rollback core of last commit due to weird behavior 2025-03-28 02:48:11 -05:00
kosinkadink1@gmail.com
9ce9ff8ef8 Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone 2025-03-28 15:29:44 +08:00
Jedrzej Kosinski
63567c0ce8 Merge branch 'master' into worksplit-multigpu 2025-03-27 22:36:46 -05:00
Jedrzej Kosinski
a786ce5ead Merge branch 'master' into worksplit-multigpu 2025-03-26 22:26:26 -05:00
Jedrzej Kosinski
4879b47648 Merge branch 'master' into worksplit-multigpu 2025-03-18 22:19:32 -05:00
Jedrzej Kosinski
5ccec33c22 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-03-17 14:27:39 -05:00
Jedrzej Kosinski
219d3cd0d0 Merge branch 'master' into worksplit-multigpu 2025-03-17 14:26:35 -05:00
Jedrzej Kosinski
c4ba399475 Merge branch 'master' into worksplit-multigpu 2025-03-15 09:12:09 -05:00
Jedrzej Kosinski
cc928a786d Merge branch 'master' into worksplit-multigpu 2025-03-13 20:59:11 -05:00
Jedrzej Kosinski
6e144b98c4 Merge branch 'master' into worksplit-multigpu 2025-03-09 00:00:38 -06:00
Jedrzej Kosinski
6dca17bd2d Satisfy ruff linting 2025-03-03 23:08:29 -06:00
Jedrzej Kosinski
5080105c23 Merge branch 'master' into worksplit-multigpu 2025-03-03 22:56:53 -06:00
Jedrzej Kosinski
093914a247 Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code 2025-03-03 22:56:13 -06:00
Jedrzej Kosinski
605893d3cf Merge branch 'master' into worksplit-multigpu 2025-02-24 19:23:16 -06:00
Jedrzej Kosinski
048f4f0b3a Merge branch 'master' into worksplit-multigpu 2025-02-17 19:35:58 -06:00
Jedrzej Kosinski
d2504fb701 Merge branch 'master' into worksplit-multigpu 2025-02-11 22:34:51 -06:00
Jedrzej Kosinski
b03763bca6 Merge branch 'multigpu_support' into worksplit-multigpu 2025-02-07 13:27:49 -06:00
Jedrzej Kosinski
476aa79b64 Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups 2025-02-06 08:44:07 -06:00
Jedrzej Kosinski
441cfd1a7a Merge branch 'master' into multigpu_support 2025-02-06 08:10:48 -06:00
Jedrzej Kosinski
99a5c1068a Merge branch 'master' into multigpu_support 2025-02-02 03:19:18 -06:00
Jedrzej Kosinski
02747cde7d Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu 2025-01-29 11:10:23 -06:00
Jedrzej Kosinski
0b3233b4e2 Merge remote-tracking branch 'origin/master' into multigpu_support 2025-01-28 06:11:07 -06:00
Jedrzej Kosinski
eda866bf51 Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet 2025-01-27 06:25:48 -06:00
Jedrzej Kosinski
e3298b84de Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support 2025-01-26 09:34:20 -06:00
Jedrzej Kosinski
c7feef9060 Cast transformer_options for multigpu 2025-01-26 05:29:27 -06:00
Jedrzej Kosinski
51af7fa1b4 Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets 2025-01-25 06:05:01 -06:00
Jedrzej Kosinski
46969c380a Initial MultiGPU support for controlnets 2025-01-24 05:39:38 -06:00
Jedrzej Kosinski
5db4277449 Make sure additional_models are unloaded as well when perform 2025-01-23 19:06:05 -06:00
Jedrzej Kosinski
02a4d0ad7d Added unload_model_and_clones to model_management.py to allow unloading only relevant models 2025-01-23 01:20:00 -06:00
Jedrzej Kosinski
ef137ac0b6 Merge branch 'multigpu_support' of https://github.com/kosinkadink/ComfyUI into multigpu_support 2025-01-20 04:34:39 -06:00
Jedrzej Kosinski
328d4f16a9 Make WeightHooks compatible with MultiGPU, clean up some code 2025-01-20 04:34:26 -06:00
Jedrzej Kosinski
bdbcb85b8d Merge branch 'multigpu_support' of https://github.com/Kosinkadink/ComfyUI into multigpu_support 2025-01-20 00:51:42 -06:00
Jedrzej Kosinski
6c9e94bae7 Merge branch 'master' into multigpu_support 2025-01-20 00:51:37 -06:00
Jedrzej Kosinski
bfce723311 Initial work on multigpu_clone function, which will account for additional_models getting cloned 2025-01-17 03:31:28 -06:00
Jedrzej Kosinski
31f5458938 Merge branch 'master' into multigpu_support 2025-01-16 18:25:05 -06:00
Jedrzej Kosinski
2145a202eb Merge branch 'master' into multigpu_support 2025-01-15 19:58:28 -06:00
Jedrzej Kosinski
25818dc848 Added a 'max_gpus' input 2025-01-14 13:45:14 -06:00
Jedrzej Kosinski
198953cd08 Add nodes_multigpu.py to loaded nodes 2025-01-14 12:24:55 -06:00
Jedrzej Kosinski
ec16ee2f39 Merge branch 'master' into multigpu_support 2025-01-13 20:21:06 -06:00
Jedrzej Kosinski
d5088072fb Make test node for multigpu instead of storing it in just a local __init__.py 2025-01-13 20:20:25 -06:00
Jedrzej Kosinski
8d4b50158e Merge branch 'master' into multigpu_support 2025-01-11 20:16:42 -06:00
Jedrzej Kosinski
e88c6c03ff Fix cond_cat to not try to cast anything that doesn't have a 'to' function 2025-01-10 23:05:24 -06:00
Jedrzej Kosinski
d3cf2b7b24 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-10 20:24:37 -06:00
Jedrzej Kosinski
7448f02b7c Initial proof of concept of giving splitting cond sampling between multiple GPUs 2025-01-08 03:33:05 -06:00
Jedrzej Kosinski
871258aa72 Add get_all_torch_devices to get detected devices intended for current torch hardware device 2025-01-07 21:06:03 -06:00
Jedrzej Kosinski
66838ebd39 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-07 20:11:27 -06:00
Jedrzej Kosinski
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
Jedrzej Kosinski
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
Jedrzej Kosinski
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
Jedrzej Kosinski
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
Jedrzej Kosinski
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
Jedrzej Kosinski
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
Jedrzej Kosinski
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
Jedrzej Kosinski
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
Jedrzej Kosinski
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
Jedrzej Kosinski
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
Jedrzej Kosinski
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
Jedrzej Kosinski
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
Jedrzej Kosinski
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
Jedrzej Kosinski
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
Jedrzej Kosinski
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
Jedrzej Kosinski
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
Jedrzej Kosinski
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
Jedrzej Kosinski
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
Jedrzej Kosinski
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
Jedrzej Kosinski
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
Jedrzej Kosinski
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
14 changed files with 763 additions and 988 deletions

View File

@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")

View File

@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -36,7 +37,7 @@ import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -63,6 +64,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -76,7 +89,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -84,6 +97,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -110,17 +124,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -129,7 +164,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -280,6 +315,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -805,6 +848,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@@ -26,6 +27,10 @@ import platform
import weakref
import gc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -171,6 +176,25 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -387,9 +411,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []
current_loaded_models: list[LoadedModel] = []
def module_size(module):
module_mem = 0
@@ -400,7 +428,7 @@ def module_size(module):
return module_mem
class LoadedModel:
def __init__(self, model):
def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -408,7 +436,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -1300,8 +1328,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
#TODO: might be cleaner to put this somewhere else
import threading

View File

@@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -240,6 +243,9 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -318,18 +324,92 @@ class ModelPatcher:
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
n.model = copy.deepcopy(n.model)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@@ -929,7 +1009,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -983,9 +1063,13 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -998,12 +1082,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1015,6 +1105,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):

167
comfy/multigpu.py Normal file
View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
# multigpu_models = model.get_additional_models_with_key("multigpu")
# new_multigpu_models = []
# for m in multigpu_models:
# if m.load_device in limit_extra_devices:
# new_multigpu_models.append(m)
# model.set_additional_models("multigpu", new_multigpu_models)
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
@@ -106,6 +108,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -130,7 +173,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -149,7 +193,7 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
'''
Registers hooks from conds.
'''
@@ -182,3 +226,18 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,4 +1,6 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
@@ -18,6 +20,7 @@ import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
import threading
def add_area_dims(area, num_dims):
@@ -140,7 +143,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -152,6 +155,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@@ -205,7 +210,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -237,7 +244,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -345,6 +352,190 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
for thread in threads:
thread.join()
for output, mult, area, batch_chunks, cond_or_uncond in results:
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -642,6 +833,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -884,7 +1077,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -893,18 +1088,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -914,8 +1108,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -923,7 +1117,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -969,6 +1162,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
@@ -979,9 +1174,13 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@@ -1,855 +0,0 @@
from __future__ import annotations
from typing import Any, Literal
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO
class InputBehavior(str, Enum):
required = "required"
optional = "optional"
def is_class(obj):
'''
Returns True if is a class type.
Returns False if is a class instance.
'''
return isinstance(obj, type)
class NumberDisplay(str, Enum):
number = "number"
slider = "slider"
class IO_V3:
'''
Base class for V3 Inputs and Outputs.
'''
def __init__(self):
pass
def __init_subclass__(cls, io_type: IO | str, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
class InputV3(IO_V3, io_type=None):
'''
Base class for a V3 Input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None):
super().__init__()
self.id = id
self.display_name = display_name
self.behavior = behavior
self.tooltip = tooltip
self.lazy = lazy
def as_dict_V1(self):
return prune_dict({
"display_name": self.display_name,
"tooltip": self.tooltip,
"lazy": self.lazy
})
def get_io_type_V1(self):
return self.io_type
class WidgetInputV3(InputV3, io_type=None):
'''
Base class for a V3 Input with widget.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: Any=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy)
self.default = default
self.socketless = socketless
self.widgetType = widgetType
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"default": self.default,
"socketless": self.socketless,
"widgetType": self.widgetType,
})
def CustomType(io_type: IO | str) -> type[IO_V3]:
name = f"{io_type}_IO_V3"
return type(name, (IO_V3,), {}, io_type=io_type)
def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
'''
Defines input for 'io_type'. Can be used to stand in for non-core types.
'''
input_kwargs = {
"id": id,
"display_name": display_name,
"behavior": behavior,
"tooltip": tooltip,
"lazy": lazy,
}
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
'''
Defines output for 'io_type'. Can be used to stand in for non-core types.
'''
input_kwargs = {
"id": id,
"display_name": display_name,
"tooltip": tooltip,
}
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
'''
Boolean input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.label_on = label_on
self.label_off = label_off
self.default: bool
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"label_on": self.label_on,
"label_off": self.label_off,
})
class IntegerInput(WidgetInputV3, io_type=IO.INT):
'''
Integer input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.min = min
self.max = max
self.step = step
self.control_after_generate = control_after_generate
self.display_mode = display_mode
self.default: int
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"control_after_generate": self.control_after_generate,
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
})
class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
'''
Float input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.default = default
self.min = min
self.max = max
self.step = step
self.round = round
self.display_mode = display_mode
self.default: float
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"round": self.round,
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
})
class StringInput(WidgetInputV3, io_type=IO.STRING):
'''
String input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.multiline = multiline
self.placeholder = placeholder
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiline": self.multiline,
"placeholder": self.placeholder,
})
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, behavior, tooltip, lazy, default, socketless, widgetType)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
})
class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, options, display_name, behavior, tooltip, lazy, default, control_after_generate, socketless, widgetType)
self.multiselect = True
self.placeholder = placeholder
self.chip = chip
self.default: list[str]
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"placeholder": self.placeholder,
"chip": self.chip,
})
class ImageInput(InputV3, io_type=IO.IMAGE):
'''
Image input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class MaskInput(InputV3, io_type=IO.MASK):
'''
Mask input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class LatentInput(InputV3, io_type=IO.LATENT):
'''
Latent input.
'''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
'''
Input that permits more than one input type.
'''
def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
super().__init__(id, display_name, behavior, tooltip)
self._io_types = io_types
@property
def io_types(self) -> list[type[InputV3]]:
'''
Returns list of InputV3 class types permitted.
'''
io_types = []
for x in self._io_types:
if not is_class(x):
io_types.append(type(x))
else:
io_types.append(x)
return io_types
def get_io_type_V1(self):
return ",".join(x.io_type for x in self.io_types)
class OutputV3:
def __init__(self, id: str, display_name: str=None, tooltip: str=None,
is_output_list=False):
self.id = id
self.display_name = display_name
self.tooltip = tooltip
self.is_output_list = is_output_list
def __init_subclass__(cls, io_type, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
class IntegerOutput(OutputV3, io_type=IO.INT):
pass
class FloatOutput(OutputV3, io_type=IO.FLOAT):
pass
class StringOutput(OutputV3, io_type=IO.STRING):
pass
# def __init__(self, id: str, display_name: str=None, tooltip: str=None):
# super().__init__(id, display_name, tooltip)
class ImageOutput(OutputV3, io_type=IO.IMAGE):
pass
class MaskOutput(OutputV3, io_type=IO.MASK):
pass
class LatentOutput(OutputV3, io_type=IO.LATENT):
pass
class DynamicInput(InputV3, io_type=None):
'''
Abstract class for dynamic input registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class DynamicOutput(OutputV3, io_type=None):
'''
Abstract class for dynamic output registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
'''
Dynamic Input that adds another template_input each time one is provided.
Additional inputs are forced to have 'InputBehavior.optional'.
'''
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None):
super().__init__("AutoGrowDynamicInput", id)
self.template_input = template_input
if min is not None:
assert(min >= 1)
if max is not None:
assert(max >= 1)
self.min = min
self.max = max
class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"):
def __init__(self, id: str):
pass
AutoGrowDynamicInput(id="dynamic", template_input=ImageInput(id="image"))
class Hidden(str, Enum):
'''
Enumerator for requesting hidden variables in nodes.
'''
unique_id = "UNIQUE_ID"
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt = "PROMPT"
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo = "EXTRA_PNGINFO"
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt = "DYNPROMPT"
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG"
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
@dataclass
class NodeInfoV1:
input: dict=None
input_order: dict[str, list[str]]=None
output: list[str]=None
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
name: str=None
display_name: str=None
description: str=None
python_module: Any=None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
api_node: bool=None
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
@dataclass
class SchemaV3:
"""Definition of V3 node properties."""
node_id: str
"""ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
display_name: str = None
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
inputs: list[InputV3]=None
outputs: list[OutputV3]=None
hidden: list[Hidden]=None
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
is_output_node: bool=False
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
is_deprecated: bool=False
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
is_experimental: bool=False
"""Flags a node as experimental, informing users that it may change or not work as expected."""
is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
# class SchemaV3Class:
# def __init__(self,
# node_id: str,
# node_name: str,
# category: str,
# inputs: list[InputV3],
# outputs: list[OutputV3]=None,
# hidden: list[Hidden]=None,
# description: str="",
# is_input_list: bool = False,
# is_output_node: bool=False,
# is_deprecated: bool=False,
# is_experimental: bool=False,
# is_api_node: bool=False,
# ):
# self.node_id = node_id
# """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
# self.node_name = node_name
# """Display name of node."""
# self.category = category
# """The category of the node, as per the "Add Node" menu."""
# self.inputs = inputs
# self.outputs = outputs
# self.hidden = hidden
# self.description = description
# """Node description, shown as a tooltip when hovering over the node."""
# self.is_input_list = is_input_list
# """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
# All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
# From the docs:
# A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
# """
# self.is_output_node = is_output_node
# """Flags this node as an output node, causing any inputs it requires to be executed.
# If a node is not connected to any output nodes, that node will not be executed. Usage::
# OUTPUT_NODE = True
# From the docs:
# By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
# """
# self.is_deprecated = is_deprecated
# """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
# self.is_experimental = is_experimental
# """Flags a node as experimental, informing users that it may change or not work as expected."""
# self.is_api_node = is_api_node
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
class ComfyNodeV3(ABC):
"""Common base class for all V3 nodes."""
RELATIVE_PYTHON_MODULE = None
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
_DESCRIPTION = None
@classproperty
def DESCRIPTION(cls):
if cls._DESCRIPTION is None:
cls.GET_SCHEMA()
return cls._DESCRIPTION
_CATEGORY = None
@classproperty
def CATEGORY(cls):
if cls._CATEGORY is None:
cls.GET_SCHEMA()
return cls._CATEGORY
_EXPERIMENTAL = None
@classproperty
def EXPERIMENTAL(cls):
if cls._EXPERIMENTAL is None:
cls.GET_SCHEMA()
return cls._EXPERIMENTAL
_DEPRECATED = None
@classproperty
def DEPRECATED(cls):
if cls._DEPRECATED is None:
cls.GET_SCHEMA()
return cls._DEPRECATED
_API_NODE = None
@classproperty
def API_NODE(cls):
if cls._API_NODE is None:
cls.GET_SCHEMA()
return cls._API_NODE
_OUTPUT_NODE = None
@classproperty
def OUTPUT_NODE(cls):
if cls._OUTPUT_NODE is None:
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_INPUT_IS_LIST = None
@classproperty
def INPUT_IS_LIST(cls):
if cls._INPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._INPUT_IS_LIST
_OUTPUT_IS_LIST = None
@classproperty
def OUTPUT_IS_LIST(cls):
if cls._OUTPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._OUTPUT_IS_LIST
_RETURN_TYPES = None
@classproperty
def RETURN_TYPES(cls):
if cls._RETURN_TYPES is None:
cls.GET_SCHEMA()
return cls._RETURN_TYPES
_RETURN_NAMES = None
@classproperty
def RETURN_NAMES(cls):
if cls._RETURN_NAMES is None:
cls.GET_SCHEMA()
return cls._RETURN_NAMES
_OUTPUT_TOOLTIPS = None
@classproperty
def OUTPUT_TOOLTIPS(cls):
if cls._OUTPUT_TOOLTIPS is None:
cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS
FUNCTION = "execute"
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.DEFINE_SCHEMA()
# for V1, make inputs be a dict with potential keys {required, optional, hidden}
input = {
"required": {}
}
if schema.inputs:
for i in schema.inputs:
input.setdefault(i.behavior.value, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
if schema.hidden:
for hidden in schema.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
return input
@classmethod
def GET_SCHEMA(cls) -> SchemaV3:
schema = cls.DEFINE_SCHEMA()
if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description
if cls._CATEGORY is None:
cls._CATEGORY = schema.category
if cls._EXPERIMENTAL is None:
cls._EXPERIMENTAL = schema.is_experimental
if cls._DEPRECATED is None:
cls._DEPRECATED = schema.is_deprecated
if cls._API_NODE is None:
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._RETURN_TYPES is None:
output = []
output_name = []
output_is_list = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_name.append(o.display_name if o.display_name else o.io_type)
output_is_list.append(o.is_output_list)
output_tooltips.append(o.tooltip if o.tooltip else None)
cls._RETURN_TYPES = output
cls._RETURN_NAMES = output_name
cls._OUTPUT_IS_LIST = output_is_list
cls._OUTPUT_TOOLTIPS = output_tooltips
return schema
@classmethod
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# get V1 inputs
input = cls.INPUT_TYPES()
# create separate lists from output fields
output = []
output_is_list = []
output_name = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
output=output,
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
name=schema.node_id,
display_name=schema.display_name,
category=schema.category,
description=schema.description,
output_node=schema.is_output_node,
deprecated=schema.is_deprecated,
experimental=schema.is_experimental,
api_node=schema.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
)
return asdict(info)
#--------------------------------------------
#############################################
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
def __init__(self):
if self.DEFINE_SCHEMA is None:
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod
def execute(self, **kwargs) -> NodeOutput:
pass
# class ReturnedInputs:
# def __init__(self):
# pass
# class ReturnedOutputs:
# def __init__(self):
# pass
class NodeOutput:
'''
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
'''
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
self.args = args
self.ui = ui
self.expand = expand
self.block_execution = block_execution
@property
def result(self):
return self.args if len(self.args) > 0 else None
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type
}
class UIOutput(ABC):
def __init__(self):
pass
@abstractmethod
def as_dict(self) -> dict:
... # TODO: finish
class UIImages(UIOutput):
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
self.values = values
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class UILatents(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"latents": values,
}
class UIAudio(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class UI3D(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class UIText(UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class TestNode(ComfyNodeV3):
SCHEMA = SchemaV3(
node_id="TestNode_v3",
display_name="Test Node (V3)",
category="v3_test",
inputs=[IntegerInput("my_int"),
#AutoGrowDynamicInput("growing", ImageInput),
MaskInput("thing"),
],
outputs=[ImageOutput("image_output")],
hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
)
# @classmethod
# def GET_SCHEMA(cls):
# return cls.SCHEMA
@classmethod
def DEFINE_SCHEMA(cls):
return cls.SCHEMA
def execute(**kwargs):
pass
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
IntegerInput("my_int"),
CustomInput("xyz", "XYZ"),
CustomInput("model1", "MODEL_M"),
ImageInput("my_image"),
FloatInput("my_float"),
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
]
outputs: list[OutputV3] = [
ImageOutput("image"),
CustomOutput("xyz", "XYZ")
]
for c in inputs:
if isinstance(c, MultitypedInput):
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
print(c.get_io_type_V1())
else:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
for c in outputs:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
zz = TestNode()
print(zz.GET_NODE_INFO_V1())
# aa = NodeInfoV1()
# print(asdict(aa))
# print(as_pruned_dict(aa))

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
class MultiGPUWorkUnitsNode:
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
NodeId = "MultiGPU_WorkUnits"
NodeName = "MultiGPU Work Units"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "init_multigpu"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
return (model,)
class MultiGPUOptionsNode:
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
"""
NodeId = "MultiGPU_Options"
NodeName = "MultiGPU Options"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("GPU_OPTIONS",)
FUNCTION = "create_gpu_options"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return (gpu_options,)
node_list = [
MultiGPUWorkUnitsNode,
MultiGPUOptionsNode
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@@ -1,67 +0,0 @@
import torch
from comfy_api.v3.io import (
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
)
class V3TestNode(ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return SchemaV3(
node_id="V3TestNode1",
display_name="V3 Test Node (1djekjd)",
description="This is a funky V3 node test.",
category="v3 nodes",
inputs=[
IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
MaskInput("mask", behavior=InputBehavior.optional),
ImageInput("image", display_name="new_image"),
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ),
# ComboDynamicInput("mask", behavior=InputBehavior.optional),
# IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider,
# dependent_inputs=[ComboDynamicInput("mask", behavior=InputBehavior.optional)],
# dependent_values=[lambda my_value: IO.STRING if my_value < 5 else IO.NUMBER],
# ),
# ["option1", "option2". "option3"]
# ComboDynamicInput["sdfgjhl", [ComboDynamicOptions("option1", [IntegerInput("some_int", display_name="new_name", min=0, tooltip="My tooltip 😎", display=NumberDisplay.slider, ImageInput(), MaskInput(), String()]),
# CombyDynamicOptons("option2", [])
# ]]
],
is_output_node=True,
)
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
a = NodeOutput(1)
aa = NodeOutput(1, "hellothere")
ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
b = NodeOutput()
c = NodeOutput(ui={"lol": "jk"})
return NodeOutput()
return NodeOutput(1)
return NodeOutput(1, block_execution="Kill yourself")
return ()
NODES_LIST: list[ComfyNodeV3] = [
V3TestNode,
]
# NODE_CLASS_MAPPINGS = {}
# NODE_DISPLAY_NAME_MAPPINGS = {}
# for node in NODES_LIST:
# schema = node.GET_SCHEMA()
# NODE_CLASS_MAPPINGS[schema.node_id] = node
# if schema.display_name:
# NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name

View File

@@ -17,7 +17,6 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput
class ExecutionResult(Enum):
SUCCESS = 0
@@ -243,22 +242,6 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
elif isinstance(r, NodeOutput):
if r.ui is not None:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))

View File

@@ -26,7 +26,6 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.v3.io import ComfyNodeV3
import comfy.clip_vision
@@ -2130,7 +2129,6 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if os.path.isdir(web_dir):
EXTENSION_WEB_DIRS[module_name] = web_dir
# V1 node definition
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore:
@@ -2139,19 +2137,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
# V3 node definition
elif getattr(module, "NODES_LIST", None) is not None:
for node_cls in module.NODES_LIST:
node_cls: ComfyNodeV3
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
logging.warning(traceback.format_exc())
@@ -2254,6 +2241,7 @@ def init_builtin_extra_nodes():
"nodes_mahiro.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",
@@ -2271,7 +2259,6 @@ def init_builtin_extra_nodes():
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
"nodes_v3_test.py",
]
import_failed = []

View File

@@ -29,7 +29,6 @@ import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from comfy_api.v3.io import ComfyNodeV3
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@@ -556,8 +555,6 @@ class PromptServer():
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
if isinstance(obj_class, ComfyNodeV3):
return obj_class.GET_NODE_INFO_V1()
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}