Compare commits
120 Commits
v0.3.24
...
worksplit-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5396e98d8 | ||
|
|
4879b47648 | ||
|
|
3b19fc76e3 | ||
|
|
5ccec33c22 | ||
|
|
219d3cd0d0 | ||
|
|
50614f1b79 | ||
|
|
6dc7b0bfe3 | ||
|
|
e8e990d6b8 | ||
|
|
2e24a15905 | ||
|
|
fd5297131f | ||
|
|
c4ba399475 | ||
|
|
55a1b09ddc | ||
|
|
3c3988df45 | ||
|
|
7ebd8087ff | ||
|
|
c624c29d66 | ||
|
|
a2448fc527 | ||
|
|
6a0daa79b6 | ||
|
|
9c98c6358b | ||
|
|
7aceb9f91c | ||
|
|
cc928a786d | ||
|
|
35504e2f93 | ||
|
|
299436cfed | ||
|
|
52e566d2bc | ||
|
|
9b6cd9b874 | ||
|
|
3fc688aebd | ||
|
|
f4411250f3 | ||
|
|
d2a0fb6bb0 | ||
|
|
01015bff16 | ||
|
|
2330754b0e | ||
|
|
bc219a6487 | ||
|
|
94689766ad | ||
|
|
cfbe4b49ca | ||
|
|
ca8efab79f | ||
|
|
65ea778a5e | ||
|
|
db9f2a34fc | ||
|
|
7946049794 | ||
|
|
6f6349b6a7 | ||
|
|
1f138dd382 | ||
|
|
b779349b55 | ||
|
|
35e2dcf5d7 | ||
|
|
67c7184b74 | ||
|
|
6f8e766509 | ||
|
|
e1da98a14a | ||
|
|
a73410aafa | ||
|
|
9aac21f894 | ||
|
|
528d1b3563 | ||
|
|
2bc4b5968f | ||
|
|
6e144b98c4 | ||
|
|
7395b0c0d1 | ||
|
|
0952569493 | ||
|
|
29832b3b61 | ||
|
|
be4e760648 | ||
|
|
c3d9cc4592 | ||
|
|
84cc9cb528 | ||
|
|
ebbb920163 | ||
|
|
d60fe0af4a | ||
|
|
5dbd250965 | ||
|
|
4ab1875283 | ||
|
|
11b1f27cb1 | ||
|
|
70e15fd743 | ||
|
|
e1474150de | ||
|
|
e62d72e8ca | ||
|
|
1650cda030 | ||
|
|
6dca17bd2d | ||
|
|
5080105c23 | ||
|
|
093914a247 | ||
|
|
605893d3cf | ||
|
|
048f4f0b3a | ||
|
|
d2504fb701 | ||
|
|
b03763bca6 | ||
|
|
476aa79b64 | ||
|
|
441cfd1a7a | ||
|
|
99a5c1068a | ||
|
|
02747cde7d | ||
|
|
0b3233b4e2 | ||
|
|
eda866bf51 | ||
|
|
e3298b84de | ||
|
|
c7feef9060 | ||
|
|
51af7fa1b4 | ||
|
|
46969c380a | ||
|
|
5db4277449 | ||
|
|
02a4d0ad7d | ||
|
|
ef137ac0b6 | ||
|
|
328d4f16a9 | ||
|
|
bdbcb85b8d | ||
|
|
6c9e94bae7 | ||
|
|
bfce723311 | ||
|
|
31f5458938 | ||
|
|
2145a202eb | ||
|
|
25818dc848 | ||
|
|
198953cd08 | ||
|
|
ec16ee2f39 | ||
|
|
d5088072fb | ||
|
|
8d4b50158e | ||
|
|
e88c6c03ff | ||
|
|
d3cf2b7b24 | ||
|
|
7448f02b7c | ||
|
|
871258aa72 | ||
|
|
66838ebd39 | ||
|
|
7333281698 | ||
|
|
3cd4c5cb0a | ||
|
|
11c6d56037 | ||
|
|
216fea15ee | ||
|
|
58bf8815c8 | ||
|
|
1b38f5bf57 | ||
|
|
2724ac4a60 | ||
|
|
f48f90e471 | ||
|
|
6463c39ce0 | ||
|
|
0a7e2ae787 | ||
|
|
03a97b604a | ||
|
|
4446c86052 | ||
|
|
8270ff312f | ||
|
|
db2d7ad9ba | ||
|
|
6620d86318 | ||
|
|
111fd0cadf | ||
|
|
776aa734e1 | ||
|
|
5a2ad032cb | ||
|
|
d44295ef71 | ||
|
|
bf21be066f | ||
|
|
72bbf49349 |
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
pause
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "1"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 30
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||
|
||||
cd ComfyUI_windows_portable_nightly_pytorch
|
||||
|
||||
@@ -19,5 +19,6 @@
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
|
||||
@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@@ -11,20 +11,43 @@ from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
|
||||
|
||||
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
except ImportError:
|
||||
# TODO: Remove the check after roll out of 0.3.16
|
||||
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
|
||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
|
||||
exit(-1)
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
|
||||
else:
|
||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check frontend version: {e}")
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
@@ -121,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
except ImportError:
|
||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
@@ -160,7 +191,8 @@ class FrontendManager:
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
|
||||
@@ -213,4 +245,5 @@ class FrontendManager:
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
STARTUP_WARNINGS = []
|
||||
|
||||
|
||||
def log_startup_warning(msg):
|
||||
logging.warning(msg)
|
||||
STARTUP_WARNINGS.append(msg)
|
||||
|
||||
|
||||
def print_startup_warnings():
|
||||
for s in STARTUP_WARNINGS:
|
||||
logging.warning(s)
|
||||
STARTUP_WARNINGS.clear()
|
||||
|
||||
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
@@ -42,10 +49,11 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
@@ -26,6 +27,7 @@ class IO(StrEnum):
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
COMBO = "COMBO"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
@@ -66,6 +68,7 @@ class IO(StrEnum):
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class RemoteInputOptions(TypedDict):
|
||||
route: str
|
||||
"""The route to the remote source."""
|
||||
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
|
||||
refresh: int
|
||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||
|
||||
|
||||
class MultiSelectOptions(TypedDict):
|
||||
placeholder: NotRequired[str]
|
||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||
chip: NotRequired[bool]
|
||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
|
||||
# default: bool
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
label_off: str
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
|
||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||
"""
|
||||
remote: RemoteInputOptions
|
||||
"""Specifies the configuration for a remote input."""
|
||||
"""Specifies the configuration for a remote input.
|
||||
Available after ComfyUI frontend v1.9.7
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||
control_after_generate: bool
|
||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||
options: NotRequired[list[str | int | float]]
|
||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||
Prefer:
|
||||
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
||||
Over:
|
||||
[["Option 1", "Option 2", "Option 3"]]
|
||||
"""
|
||||
multi_select: NotRequired[MultiSelectOptions]
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@@ -36,7 +37,7 @@ import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@@ -63,6 +64,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@@ -76,7 +89,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@@ -84,6 +97,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@@ -110,17 +124,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@@ -129,7 +164,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@@ -280,6 +315,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@@ -804,6 +847,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
141
comfy/image_encoders/dino2.py
Normal file
141
comfy/image_encoders/dino2.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
in_features = out_features = dim
|
||||
hidden_features = int(dim * 4)
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
|
||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.weights_in(x)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.nn.functional.silu(x1) * x2
|
||||
return self.weights_out(x)
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
21
comfy/image_encoders/dino2_giant.json
Normal file
21
comfy/image_encoders/dino2_giant.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 1536,
|
||||
"image_size": 518,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"layerscale_value": 1.0,
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_attention_heads": 24,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 40,
|
||||
"patch_size": 14,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
ops.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
@@ -171,16 +175,16 @@ class StageA(nn.Module):
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
@@ -191,7 +195,7 @@ class StageA(nn.Module):
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
@@ -199,11 +203,11 @@ class StageA(nn.Module):
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
|
||||
@@ -19,6 +19,9 @@ import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
return tensor
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
|
||||
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
@@ -115,8 +115,11 @@ class Flux(nn.Module):
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
else:
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
|
||||
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is not None:
|
||||
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = img[:, : img_len]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
return out
|
||||
|
||||
@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
|
||||
@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
|
||||
context,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
# arguments
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
freqs=freqs,
|
||||
context=context)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
||||
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
patch_size = self.patch_size
|
||||
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
r"""
|
||||
|
||||
@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
@@ -898,20 +898,31 @@ class HunyuanVideo(BaseModel):
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
guiding_frame_index = kwargs.get("guiding_frame_index", None)
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image", "mask_inverted")
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image",)
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class CosmosVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||
@@ -962,11 +973,11 @@ class WAN21(BaseModel):
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
if not self.image_to_video:
|
||||
noise = kwargs.get("noise", None)
|
||||
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
@@ -976,6 +987,9 @@ class WAN21(BaseModel):
|
||||
image = self.process_latent_in(image)
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if not self.image_to_video:
|
||||
return image
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :4]
|
||||
|
||||
@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@@ -26,6 +27,10 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -145,6 +150,25 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
devices.remove(get_torch_device())
|
||||
return devices
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@@ -186,12 +210,21 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
mac_ver = mac_version()
|
||||
if mac_ver is not None:
|
||||
logging.info("Mac Version {}".format(mac_ver))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -347,9 +380,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@@ -360,7 +397,7 @@ def module_size(module):
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@@ -368,7 +405,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@@ -581,7 +618,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
@@ -921,6 +958,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
def flash_attention_enabled():
|
||||
return args.use_flash_attention
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -969,12 +1009,6 @@ def pytorch_attention_flash_attention():
|
||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
return False
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
|
||||
@@ -1213,6 +1247,31 @@ def soft_empty_cache(force=False):
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@@ -240,6 +243,9 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -318,18 +324,92 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@@ -747,6 +827,7 @@ class ModelPatcher:
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
@@ -770,6 +851,10 @@ class ModelPatcher:
|
||||
move_weight = False
|
||||
break
|
||||
|
||||
if not hooks_unpatched:
|
||||
self.unpatch_hooks()
|
||||
hooks_unpatched = True
|
||||
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
@@ -924,7 +1009,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@@ -978,9 +1063,13 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options, ignore_multigpu)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@@ -993,12 +1082,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@@ -1010,6 +1105,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
@@ -1089,7 +1206,6 @@ class ModelPatcher:
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
if hooks is not None:
|
||||
model_sd_keys = list(self.model_state_dict().keys())
|
||||
memory_counter = None
|
||||
@@ -1100,12 +1216,16 @@ class ModelPatcher:
|
||||
# if have cached weights for hooks, use it
|
||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||
if cached_weights is not None:
|
||||
model_sd_keys_set = set(model_sd_keys)
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
model_sd_keys_set.remove(key)
|
||||
self.unpatch_hooks(model_sd_keys_set)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||
original_weights = None
|
||||
if len(relevant_patches) > 0:
|
||||
@@ -1116,6 +1236,8 @@ class ModelPatcher:
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
memory_counter=memory_counter)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
self.current_hooks = hooks
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
@@ -1172,17 +1294,23 @@ class ModelPatcher:
|
||||
del out_weight
|
||||
del weight
|
||||
|
||||
def unpatch_hooks(self) -> None:
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
self.current_hooks = None
|
||||
return
|
||||
keys = list(self.hook_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
if whitelist_keys_set:
|
||||
for k in keys:
|
||||
if k in whitelist_keys_set:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
self.hook_backup.pop(k)
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
|
||||
def clean_hooks(self):
|
||||
self.unpatch_hooks()
|
||||
|
||||
176
comfy/multigpu.py
Normal file
176
comfy/multigpu.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
def get_torch_device_list():
|
||||
devices = ["default"]
|
||||
for device in comfy.model_management.get_all_torch_devices():
|
||||
device: torch.device
|
||||
devices.append(str(device.index))
|
||||
return devices
|
||||
|
||||
def get_device_from_str(device_str: str, throw_error_if_not_found=False):
|
||||
if device_str == "default":
|
||||
return comfy.model_management.get_torch_device()
|
||||
for device in comfy.model_management.get_all_torch_devices():
|
||||
device: torch.device
|
||||
if str(device.index) == device_str:
|
||||
return device
|
||||
if throw_error_if_not_found:
|
||||
raise Exception(f"Device with index '{device_str}' not found.")
|
||||
logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
extra_devices = extra_devices[:max_gpus-1]
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# persist skip_devices for use in sampling code
|
||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
@@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
@@ -104,16 +106,57 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||
real_model = model.model
|
||||
real_model: BaseModel = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@@ -126,7 +169,7 @@ def cleanup_models(conds, models):
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
@@ -159,3 +202,18 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
@@ -18,6 +20,7 @@ import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
import threading
|
||||
|
||||
|
||||
def add_area_dims(area, num_dims):
|
||||
@@ -140,7 +143,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@@ -152,6 +155,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@@ -205,7 +210,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -237,7 +244,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@@ -339,6 +346,190 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
|
||||
total_conds = 0
|
||||
for to_run in hooked_to_run.values():
|
||||
total_conds += len(to_run)
|
||||
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
||||
index_device = 0
|
||||
current_device = devices[index_device]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
current_device = devices[index_device % len(devices)]
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
# keep track of conds currently scheduled onto this device
|
||||
batched_to_run_length = 0
|
||||
for btr in batched_to_run:
|
||||
batched_to_run_length += len(btr[1])
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
# make sure not over conds_per_device limit when creating temp batch
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
batched_to_run_length += len(conds_to_batch)
|
||||
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
if batched_to_run_length >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
|
||||
|
||||
results: list[thread_result] = []
|
||||
threads: list[threading.Thread] = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
||||
threads.append(new_thread)
|
||||
new_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond in results:
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@@ -636,6 +827,8 @@ def pre_run_control(model, conds):
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device_cnet in x['control'].multigpu_clones.values():
|
||||
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@@ -710,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation"]
|
||||
"gradient_estimation", "er_sde"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
@@ -878,7 +1071,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@@ -887,18 +1082,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@@ -908,8 +1102,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@@ -917,7 +1111,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@@ -963,6 +1156,8 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
@@ -973,9 +1168,13 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
15
comfy/sd.py
15
comfy/sd.py
@@ -440,6 +440,10 @@ class VAE:
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
@@ -495,6 +499,7 @@ class VAE:
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
@@ -525,6 +530,7 @@ class VAE:
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
dims = samples.ndim - 2
|
||||
@@ -553,6 +559,7 @@ class VAE:
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
@@ -585,6 +592,7 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
return None
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
|
||||
@@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -42,7 +42,7 @@ class HunyuanVideoTokenizer:
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
@@ -56,7 +56,7 @@ class HunyuanVideoTokenizer:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 128257:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
|
||||
embed_count += 1
|
||||
out["llama"] = llama_text_tokens
|
||||
return out
|
||||
@@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
|
||||
template_end = 0
|
||||
image_start = None
|
||||
image_end = None
|
||||
extra_template_end = 0
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
images = []
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
@@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if image_start is None:
|
||||
if template_end > 0:
|
||||
if user_end == -1:
|
||||
extra_template_end += elem_size - 1
|
||||
else:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
extra_sizes += elem_size - 1
|
||||
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if image_start is not None:
|
||||
image_output = llama_out[:, image_start: image_end]
|
||||
llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
|
||||
if len(images) > 0:
|
||||
out = []
|
||||
for i in images:
|
||||
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
|
||||
@@ -57,17 +57,17 @@ class TextEncodeHunyuanVideo_ImageToVideo:
|
||||
"clip": ("CLIP", ),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, clip_vision_output, prompt):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
|
||||
def encode(self, clip, clip_vision_output, prompt, image_interleave):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class HunyuanImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -77,6 +77,7 @@ class HunyuanImageToVideo:
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
@@ -87,8 +88,10 @@ class HunyuanImageToVideo:
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
|
||||
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
@@ -96,13 +99,20 @@ class HunyuanImageToVideo:
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, out_latent)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||
|
||||
@@ -19,8 +19,6 @@ class Load3D():
|
||||
"image": ("LOAD_3D", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -55,8 +53,6 @@ class Load3DAnimation():
|
||||
"image": ("LOAD_3D_ANIMATION", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -82,8 +78,6 @@ class Preview3D():
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -102,8 +96,6 @@ class Preview3DAnimation():
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@@ -99,12 +99,13 @@ class LTXVAddGuide:
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE",),
|
||||
"latent": ("LATENT",),
|
||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
|
||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
|
||||
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
|
||||
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
|
||||
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
|
||||
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
|
||||
"Negative values are counted from the end of the video."}),
|
||||
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
|
||||
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
|
||||
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
|
||||
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
@@ -127,12 +128,13 @@ class LTXVAddGuide:
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
|
||||
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1:
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
|
||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
@@ -191,7 +193,7 @@ class LTXVAddGuide:
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
|
||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
|
||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
|
||||
|
||||
108
comfy_extras/nodes_multigpu.py
Normal file
108
comfy_extras/nodes_multigpu.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
from nodes import VAELoader
|
||||
|
||||
|
||||
class VAELoaderDevice(VAELoader):
|
||||
NodeId = "VAELoaderDevice"
|
||||
NodeName = "Load VAE MultiGPU"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"vae_name": (cls.vae_list(), ),
|
||||
"load_device": (comfy.multigpu.get_torch_device_list(), ),
|
||||
}
|
||||
}
|
||||
|
||||
FUNCTION = "load_vae_device"
|
||||
CATEGORY = "advanced/multigpu/loaders"
|
||||
|
||||
def load_vae_device(self, vae_name, load_device: str):
|
||||
device = comfy.multigpu.get_device_from_str(load_device)
|
||||
return self.load_vae(vae_name, device)
|
||||
|
||||
class MultiGPUWorkUnitsNode:
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_WorkUnits"
|
||||
NodeName = "MultiGPU Work Units"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "init_multigpu"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
|
||||
return (model,)
|
||||
|
||||
class MultiGPUOptionsNode:
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_Options"
|
||||
NodeName = "MultiGPU Options"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
|
||||
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GPU_OPTIONS",)
|
||||
FUNCTION = "create_gpu_options"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return (gpu_options,)
|
||||
|
||||
|
||||
node_list = [
|
||||
MultiGPUWorkUnitsNode,
|
||||
MultiGPUOptionsNode,
|
||||
VAELoaderDevice,
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
for node in node_list:
|
||||
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.24"
|
||||
__version__ = "0.3.26"
|
||||
|
||||
@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||
# list widget value to execution, as by default list value is
|
||||
# reserved to represent the connection between nodes.
|
||||
if isinstance(val, dict) and "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
|
||||
6
main.py
6
main.py
@@ -139,6 +139,7 @@ from server import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
|
||||
|
||||
def cuda_malloc_warning():
|
||||
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
|
||||
if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
event_loop.run_until_complete(start_all_func())
|
||||
x = start_all_func()
|
||||
app.logger.print_startup_warnings()
|
||||
event_loop.run_until_complete(x)
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
||||
|
||||
17
nodes.py
17
nodes.py
@@ -489,7 +489,7 @@ class SaveLatent:
|
||||
file = os.path.join(full_output_folder, file)
|
||||
|
||||
output = {}
|
||||
output["latent_tensor"] = samples["samples"]
|
||||
output["latent_tensor"] = samples["samples"].contiguous()
|
||||
output["latent_format_version_0"] = torch.tensor([])
|
||||
|
||||
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
||||
@@ -763,13 +763,14 @@ class VAELoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
def load_vae(self, vae_name, device=None):
|
||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
vae = comfy.sd.VAE(sd=sd, device=device)
|
||||
vae.throw_exception_if_invalid()
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
|
||||
|
||||
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
|
||||
EXPERIMENTAL = True
|
||||
FUNCTION = "load_image_output"
|
||||
|
||||
def load_image_output(self, image):
|
||||
return self.load_image(f"{image} [output]")
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
return True
|
||||
FUNCTION = "load_image"
|
||||
|
||||
|
||||
class ImageScale:
|
||||
@@ -2265,6 +2259,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_mahiro.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.24"
|
||||
version = "0.3.26"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.10.17
|
||||
comfyui-frontend-package==1.12.14
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
|
||||
def test_init_frontend_default():
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
||||
assert frontend_path == FrontendManager.default_frontend_path()
|
||||
|
||||
|
||||
def test_init_frontend_invalid_version():
|
||||
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_os_functions():
|
||||
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
||||
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
||||
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
||||
with (
|
||||
patch("app.frontend_management.os.makedirs") as mock_makedirs,
|
||||
patch("app.frontend_management.os.listdir") as mock_listdir,
|
||||
patch("app.frontend_management.os.rmdir") as mock_rmdir,
|
||||
):
|
||||
mock_listdir.return_value = [] # Simulate empty directory
|
||||
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_download():
|
||||
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
||||
with patch("app.frontend_management.download_release_asset_zip") as mock:
|
||||
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||
yield mock
|
||||
|
||||
|
||||
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||
# Arrange
|
||||
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||
version_string = 'test-owner/test-repo@1.0.0'
|
||||
version_string = "test-owner/test-repo@1.0.0"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception):
|
||||
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
|
||||
version_string = "invalid"
|
||||
with pytest.raises(argparse.ArgumentTypeError):
|
||||
FrontendManager.parse_version_string(version_string)
|
||||
|
||||
|
||||
def test_init_frontend_default_with_mocks():
|
||||
# Arrange
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/mocked/path"
|
||||
),
|
||||
):
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
|
||||
# Assert
|
||||
assert frontend_path == "/mocked/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
|
||||
def test_init_frontend_fallback_on_error():
|
||||
# Arrange
|
||||
version_string = "test-owner/test-repo@1.0.0"
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(
|
||||
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
|
||||
),
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/default/path"
|
||||
),
|
||||
):
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
|
||||
# Assert
|
||||
assert frontend_path == "/default/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user