Compare commits

...

59 Commits

Author SHA1 Message Date
Jedrzej Kosinski
c8037ab667 Initial exploration of weight zipper 2025-03-24 03:34:42 -05:00
comfyanonymous
3b19fc76e3 Allow disabling pe in flux code for some other models. 2025-03-18 05:09:25 -04:00
comfyanonymous
50614f1b79 Fix regression with clip vision. 2025-03-17 13:56:11 -04:00
comfyanonymous
6dc7b0bfe3 Add support for giant dinov2 image encoder. 2025-03-17 05:53:54 -04:00
comfyanonymous
e8e990d6b8 Cleanup code. 2025-03-16 06:29:12 -04:00
Jedrzej Kosinski
2e24a15905 Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253)
* Call unpatch_hooks at the start of ModelPatcher.partially_unload

* Only call unpatch_hooks in partially_unload if lowvram is possible
2025-03-16 06:02:45 -04:00
chaObserv
fd5297131f Guard the edge cases of noise term in er_sde (#7265) 2025-03-16 06:02:25 -04:00
comfyanonymous
55a1b09ddc Allow loading diffusion model files with the "Load Checkpoint" node. 2025-03-15 08:27:49 -04:00
comfyanonymous
3c3988df45 Show a better error message if the VAE is invalid. 2025-03-15 08:26:36 -04:00
Christian Byrne
7ebd8087ff hotfix fe (#7244) 2025-03-15 01:38:10 -04:00
Chenlei Hu
c624c29d66 Update frontend to 1.12.9 (#7236)
* Update frontend to 1.12.9

* Update requirements.txt
2025-03-14 18:17:26 -04:00
comfyanonymous
a2448fc527 Remove useless code. 2025-03-14 18:10:37 -04:00
comfyanonymous
6a0daa79b6 Make the SkipLayerGuidanceDIT node work on WAN. 2025-03-14 10:55:19 -04:00
FeepingCreature
9c98c6358b Tolerate missing @torch.library.custom_op (#7234)
This can happen on Pytorch versions older than 2.4.
2025-03-14 09:51:26 -04:00
FeepingCreature
7aceb9f91c Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag.
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
2025-03-14 03:22:41 -04:00
comfyanonymous
35504e2f93 Fix. 2025-03-13 15:03:18 -04:00
comfyanonymous
299436cfed Print mac version. 2025-03-13 10:05:40 -04:00
Chenlei Hu
52e566d2bc Add codeowner for comfy/comfy_types (#7213) 2025-03-12 17:30:00 -04:00
Chenlei Hu
9b6cd9b874 [NodeDef] Add documentation on multi_select input option (#7212) 2025-03-12 17:29:39 -04:00
chaObserv
3fc688aebd Ensure the extra_args in dpmpp sde series (#7204) 2025-03-12 17:28:59 -04:00
comfyanonymous
f4411250f3 Repeat frontend version warning at the end.
This way someone running ComfyUI with the command line is more likely to
actually see it.
2025-03-12 07:13:40 -04:00
Chenlei Hu
d2a0fb6bb0 Add unwrap widget value support (#7197)
* Add unwrap widget value support

* nit
2025-03-12 06:39:14 -04:00
chaObserv
01015bff16 Add er_sde sampler (#7187) 2025-03-12 02:42:37 -04:00
comfyanonymous
2330754b0e Fix error saving some latents. 2025-03-11 15:07:16 -04:00
comfyanonymous
bc219a6487 Merge pull request #7143 from christian-byrne/fix-remote-widget-node
Fix LoadImageOutput node
2025-03-11 04:30:25 -04:00
comfyanonymous
94689766ad Merge pull request #7179 from comfyanonymous/ignore_fe_package
Only check frontend package if using default frontend
2025-03-11 03:45:02 -04:00
huchenlei
cfbe4b49ca Access package version 2025-03-10 20:43:59 -04:00
comfyanonymous
ca8efab79f Support control loras on Wan. 2025-03-10 17:23:13 -04:00
Chenlei Hu
65ea778a5e nit 2025-03-10 15:19:59 -04:00
Chenlei Hu
db9f2a34fc Fix unit test 2025-03-10 15:19:52 -04:00
Chenlei Hu
7946049794 nit 2025-03-10 15:14:40 -04:00
Chenlei Hu
6f6349b6a7 nit 2025-03-10 15:10:40 -04:00
Chenlei Hu
1f138dd382 Only check frontend package if using default frontend 2025-03-10 15:07:44 -04:00
comfyanonymous
b779349b55 Temporarily revert fix to give time for people to update their nodes. 2025-03-10 06:30:17 -04:00
comfyanonymous
35e2dcf5d7 Hack to fix broken manager. 2025-03-10 06:15:17 -04:00
Andrew Kvochko
67c7184b74 ltxv: relax frame_idx divisibility for single frames. (#7146)
This commit relaxes divisibility constraint for single-frame
conditionings. For single frames, the index can be arbitrary, while
multi-frame conditionings (>= 9 frames) must still be aligned to 8
frames.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-03-10 04:11:48 -04:00
comfyanonymous
6f8e766509 Prevent custom nodes from accidentally overwriting global modules. 2025-03-10 03:33:41 -04:00
Terry Jia
e1da98a14a remove unused params (#6931) 2025-03-09 14:07:09 -04:00
bymyself
a73410aafa remove overrides 2025-03-09 03:46:08 -07:00
comfyanonymous
9aac21f894 Fix issues with new hunyuan img2vid model and bump version to v0.3.26 2025-03-09 05:07:22 -04:00
Jedrzej Kosinski
528d1b3563 When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys (#7067) 2025-03-09 04:26:31 -04:00
comfyanonymous
2bc4b5968f ComfyUI version v0.3.25 2025-03-09 03:30:20 -04:00
comfyanonymous
7395b0c0d1 Support new hunyuan video i2v model.
Use the new "v2 (replace)" guidance type in HunyuanImageToVideo and set
image_interleave to 4 on the "Text Encode Hunyuan Video" node.
2025-03-08 20:34:47 -05:00
comfyanonymous
0952569493 Fix stable cascade VAE on some lowvram machines. 2025-03-08 20:24:04 -05:00
comfyanonymous
29832b3b61 Warn if frontend package is older than the one in requirements.txt 2025-03-08 03:51:36 -05:00
comfyanonymous
be4e760648 Add an image_interleave option to the Hunyuan image to video encode node.
See the tooltip for what it does.
2025-03-07 19:56:26 -05:00
comfyanonymous
c3d9cc4592 Print the frontend version in the log. 2025-03-07 19:56:26 -05:00
Chenlei Hu
84cc9cb528 Update frontend to 1.11.8 (#7119)
* Update frontend to 1.11.7

* Update requirements.txt
2025-03-07 19:02:13 -05:00
comfyanonymous
ebbb920163 Add back taesd to nightly package. 2025-03-07 14:56:09 -05:00
comfyanonymous
d60fe0af4a Reduce size of nightly package. 2025-03-07 08:30:01 -05:00
comfyanonymous
5dbd250965 Update nightly instructions in readme. 2025-03-07 07:57:59 -05:00
comfyanonymous
4ab1875283 Add .bat file to nightly package to run with fp16 accumulation. 2025-03-07 07:45:40 -05:00
comfyanonymous
11b1f27cb1 Set WAN default compute dtype to fp16. 2025-03-07 04:52:36 -05:00
comfyanonymous
70e15fd743 No need for scale_input when fp8 matrix mult is disabled. 2025-03-07 04:49:20 -05:00
comfyanonymous
e1474150de Support fp8_scaled diffusion models that don't use fp8 matrix mult. 2025-03-07 04:39:21 -05:00
JettHu
e62d72e8ca Typo in node_typing.py (#7092) 2025-03-06 15:24:04 -05:00
Dr.Lt.Data
1650cda030 Fixed: Incorrect guide message for missing frontend. (#7105)
`{sys.executable} -m pip` -> `{sys.executable} -s -m pip`

https://github.com/comfyanonymous/ComfyUI/pull/7047#issuecomment-2697876793
2025-03-06 15:23:23 -05:00
comfyanonymous
a13125840c ComfyUI version v0.3.24 2025-03-06 13:53:48 -05:00
comfyanonymous
dfa36e6855 Fix some things breaking when embeddings fail to apply. 2025-03-06 13:31:55 -05:00
40 changed files with 897 additions and 177 deletions

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
pause

View File

@@ -7,7 +7,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "126"
+        default: "128"

       python_minor:
         description: 'python minor version'
@@ -19,7 +19,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "1"
+        default: "2"
 #  push:
 #    branches:
 #      - master
@@ -34,7 +34,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
         with:
-          fetch-depth: 0
+          fetch-depth: 30
           persist-credentials: false
       - uses: actions/setup-python@v5
         with:
@@ -74,7 +74,7 @@ jobs:
           pause" > ./update/update_comfyui_and_python_dependencies.bat
           cd ..
-          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
+          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
           mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
           cd ComfyUI_windows_portable_nightly_pytorch

View File

@@ -19,5 +19,6 @@
 /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

-# Extra nodes
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
+# Node developers
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

View File

@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:

 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```

-This is the command to install pytorch nightly instead which might have performance improvements:
+This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```

 #### Troubleshooting

View File

@@ -11,20 +11,43 @@ from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import TypedDict, Optional
+from importlib.metadata import version

 import requests
 from typing_extensions import NotRequired

 from comfy.cli_args import DEFAULT_VERSION_STRING
+import app.logger
+
+# The path to the requirements.txt file
+req_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def frontend_install_warning_message():
+    """The warning message to display when the frontend version is not up to date."""
+    extra = ""
+    if sys.flags.no_user_site:
+        extra = "-s "
+    return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
+
+
+def check_frontend_version():
+    """Check if the frontend version is up to date."""
+
+    def parse_version(version: str) -> tuple[int, int, int]:
+        return tuple(map(int, version.split(".")))
+
+    try:
+        frontend_version_str = version("comfyui-frontend-package")
+        frontend_version = parse_version(frontend_version_str)
+        with open(req_path, "r", encoding="utf-8") as f:
+            required_frontend = parse_version(f.readline().split("=")[-1])
+        if frontend_version < required_frontend:
+            app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
+        else:
+            logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
+    except Exception as e:
+        logging.error(f"Failed to check frontend version: {e}")
+

-try:
-    import comfyui_frontend_package
-except ImportError:
-    # TODO: Remove the check after roll out of 0.3.16
-    req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
-    logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
-    exit(-1)

 REQUEST_TIMEOUT = 10  # seconds
@@ -121,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:

 class FrontendManager:
-    DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

+    @classmethod
+    def default_frontend_path(cls) -> str:
+        try:
+            import comfyui_frontend_package
+            return str(importlib.resources.files(comfyui_frontend_package) / "static")
+        except ImportError:
+            logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
+            sys.exit(-1)
+
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """
@@ -160,7 +191,8 @@ class FrontendManager:
             main error source might be request timeout or invalid URL.
         """
         if version_string == DEFAULT_VERSION_STRING:
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()

         repo_owner, repo_name, version = cls.parse_version_string(version_string)
@@ -213,4 +245,5 @@ class FrontendManager:
         except Exception as e:
             logging.error("Failed to initialize frontend: %s", e)
             logging.info("Falling back to the default frontend.")
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()

View File

@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool

     logger.addHandler(stdout_handler)
     logger.addHandler(stream_handler)
+
+
+STARTUP_WARNINGS = []
+
+
+def log_startup_warning(msg):
+    logging.warning(msg)
+    STARTUP_WARNINGS.append(msg)
+
+
+def print_startup_warnings():
+    for s in STARTUP_WARNINGS:
+        logging.warning(s)
+    STARTUP_WARNINGS.clear()
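
These helpers implement a collect-and-replay pattern: a warning raised during startup is logged immediately and buffered, then repeated once startup is done so it is the last thing visible in the console (see the commit "Repeat frontend version warning at the end"). A hypothetical call sequence, using only the functions from this diff:

import app.logger

app.logger.log_startup_warning("Installed frontend version 1.11.8 is lower than the recommended version 1.12.9.")
# ... the rest of startup runs, possibly printing many lines ...
app.logger.print_startup_warnings()  # replays the buffered warnings at the end, then clears them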

View File

@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

View File

@@ -9,6 +9,7 @@ import comfy.model_patcher
 import comfy.model_management
 import comfy.utils
 import comfy.clip_model
+import comfy.image_encoders.dino2

 class Output:
     def __getitem__(self, key):
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
     image = torch.clip((255. * image), 0, 255).round() / 255.0
     return (image - mean.view([3,1,1])) / std.view([3,1,1])

+IMAGE_ENCODERS = {
+    "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+}
+
 class ClipVisionModel():
     def __init__(self, json_config):
         with open(json_config) as f:
@@ -42,10 +49,11 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
+        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
         self.model.eval()

         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+    elif "embeddings.patch_embeddings.projection.weight" in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
     else:
         return None
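
The IMAGE_ENCODERS registry added above dispatches on the config's "model_type" key, so a new encoder only needs a registry entry plus a config JSON. A rough sketch of the lookup (the config literal here is illustrative):

config = {"model_type": "dinov2"}  # normally loaded from a JSON such as dino2_giant.json
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
# -> comfy.image_encoders.dino2.Dinov2Model; an unknown model_type yields None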

View File

@@ -2,6 +2,7 @@
 from __future__ import annotations

 from typing import Literal, TypedDict
+from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum

@@ -26,6 +27,7 @@ class IO(StrEnum):
     BOOLEAN = "BOOLEAN"
     INT = "INT"
     FLOAT = "FLOAT"
+    COMBO = "COMBO"
     CONDITIONING = "CONDITIONING"
     SAMPLER = "SAMPLER"
     SIGMAS = "SIGMAS"
@@ -66,6 +68,7 @@ class IO(StrEnum):
         b = frozenset(value.split(","))
         return not (b.issubset(a) or a.issubset(b))

+
 class RemoteInputOptions(TypedDict):
     route: str
     """The route to the remote source."""
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
     refresh: int
     """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""

+
+class MultiSelectOptions(TypedDict):
+    placeholder: NotRequired[str]
+    """The placeholder text to display in the multi-select widget when no items are selected."""
+    chip: NotRequired[bool]
+    """Specifies whether to use chips instead of comma separated values for the multi-select widget."""
+
+
 class InputTypeOptions(TypedDict):
     """Provides type hinting for the return type of the INPUT_TYPES node function.
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
     # default: bool
     label_on: str
     """The label to use in the UI when the bool is True (``BOOLEAN``)"""
-    label_on: str
+    label_off: str
     """The label to use in the UI when the bool is False (``BOOLEAN``)"""
     # class InputTypeString(InputTypeOptions):
     #     default: str
@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
     """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
     """
     remote: RemoteInputOptions
-    """Specifies the configuration for a remote input."""
+    """Specifies the configuration for a remote input.
+    Available after ComfyUI frontend v1.9.7
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
     control_after_generate: bool
     """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
+    options: NotRequired[list[str | int | float]]
+    """COMBO type only. Specifies the selectable options for the combo widget.
+    Prefer:
+    ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
+    Over:
+    [["Option 1", "Option 2", "Option 3"]]
+    """
+    multi_select: NotRequired[MultiSelectOptions]
+    """COMBO type only. Specifies the configuration for a multi-select widget.
+    Available after ComfyUI frontend v1.13.4
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""

 class HiddenInputTypeDict(TypedDict):
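
As a worked example of the COMBO options documented above, a custom node's INPUT_TYPES could look like the following (the node and field names are hypothetical; the dict shapes follow the docstrings in this diff):

class StylePicker:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # explicit "options" form, preferred over the bare nested-list form
                "style": ("COMBO", {"options": ["photo", "anime", "sketch"]}),
                # multi-select COMBO, available after frontend v1.13.4
                "extras": ("COMBO", {
                    "options": ["grain", "vignette", "bloom"],
                    "multi_select": {"placeholder": "pick any", "chip": True},
                }),
            }
        }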

View File

@@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device


class Dino2AttentionOutput(torch.nn.Module):
    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.dense(x)


class Dino2AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        return self.output(self.attention(x, mask, optimized_attention))


class LayerScale(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)


class SwiGLUFFN(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        in_features = out_features = dim
        hidden_features = int(dim * 4)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.weights_in(x)
        x1, x2 = x.chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        return self.weights_out(x)


class Dino2Block(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention):
        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x


class Dino2Encoder(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])

    def forward(self, x, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layer):
            x = l(x, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class Dino2PatchEmbeddings(torch.nn.Module):
    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device
        )

    def forward(self, pixel_values):
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class Dino2Embeddings(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        patch_size = 14
        image_size = 518

        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))

    def forward(self, pixel_values):
        x = self.patch_embeddings(pixel_values)
        # TODO: mask_token?
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
        return x


class Dinov2Model(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]
        return x, i, pooled_output, None
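
A rough usage sketch for the new encoder, mirroring how ClipVisionModel wires it up in the clip_vision.py diff above (the path, dtype and dummy input are illustrative, and the weights here are uninitialized):

import json
import torch
import comfy.ops
from comfy.image_encoders.dino2 import Dinov2Model

with open("comfy/image_encoders/dino2_giant.json") as f:
    config = json.load(f)

model = Dinov2Model(config, dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)
pixel_values = torch.zeros(1, 3, 518, 518)  # 518/14 = 37, so 37*37 = 1369 patches + 1 cls token
tokens, intermediate, pooled, _ = model(pixel_values)
# tokens: (1, 1370, 1536); pooled: the cls token, (1, 1536)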

View File

@@ -0,0 +1,21 @@
{
  "attention_probs_dropout_prob": 0.0,
  "drop_path_rate": 0.0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1536,
  "image_size": 518,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-06,
  "layerscale_value": 1.0,
  "mlp_ratio": 4,
  "model_type": "dinov2",
  "num_attention_heads": 24,
  "num_channels": 3,
  "num_hidden_layers": 40,
  "patch_size": 14,
  "qkv_bias": true,
  "use_swiglu_ffn": true,
  "image_mean": [0.485, 0.456, 0.406],
  "image_std": [0.229, 0.224, 0.225]
}

View File

@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     if len(sigmas) <= 1:
         return x

+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if solver_type not in {'heun', 'midpoint'}:
         raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     old_denoised = None
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if len(sigmas) <= 1:
         return x

+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     denoised_1, denoised_2 = None, None
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
 def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         x = x + d_bar * dt
         old_d = d
     return x
+
+
+@torch.no_grad()
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
+    """
+    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
+    """
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    def default_noise_scaler(sigma):
+        return sigma * ((sigma ** 0.3).exp() + 10.0)
+    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    num_integration_points = 200.0
+    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
+
+    old_denoised = None
+    old_denoised_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        stage_used = min(max_stage, i + 1)
+        if sigmas[i + 1] == 0:
+            x = denoised
+        elif stage_used == 1:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+        else:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+
+            dt = sigmas[i + 1] - sigmas[i]
+            sigma_step_size = -dt / num_integration_points
+            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
+            scaled_pos = noise_scaler(sigma_pos)
+
+            # Stage 2
+            s = torch.sum(1 / scaled_pos) * sigma_step_size
+            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
+            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+
+            if stage_used >= 3:
+                # Stage 3
+                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
+                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
+                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
+            old_denoised_d = denoised_d
+
+        if s_noise != 0 and sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+        old_denoised = denoised
+    return x
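
The "Guard the edge cases of noise term in er_sde" fix (fd5297131f) is the .nan_to_num(nan=0.0) on the final noise line above: when sigmas[i + 1] ** 2 < sigmas[i] ** 2 * r ** 2 the variance under the square root is negative and sqrt() returns NaN, so the guard drops the noise term instead of poisoning x. A tiny standalone illustration (values made up):

import torch

sigma_next, sigma, r = torch.tensor(0.5), torch.tensor(1.0), torch.tensor(0.8)
var = sigma_next ** 2 - sigma ** 2 * r ** 2   # 0.25 - 0.64 = -0.39
print(var.sqrt())                             # tensor(nan)
print(var.sqrt().nan_to_num(nan=0.0))         # tensor(0.), the noise term drops out safely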

View File

@@ -19,6 +19,10 @@
 import torch
 from torch import nn
 from torch.autograd import Function
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
+

 class vector_quantize(Function):
     @staticmethod
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
         self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.depthwise = nn.Sequential(
             nn.ReplicationPad2d(1),
-            nn.Conv2d(c, c, kernel_size=3, groups=c)
+            ops.Conv2d(c, c, kernel_size=3, groups=c)
         )

         # channelwise
         self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c_hidden),
+            ops.Linear(c, c_hidden),
             nn.GELU(),
-            nn.Linear(c_hidden, c),
+            ops.Linear(c_hidden, c),
         )

         self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@@ -171,16 +175,16 @@ class StageA(nn.Module):
         # Encoder blocks
         self.in_block = nn.Sequential(
             nn.PixelUnshuffle(2),
-            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+            ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
         )
         down_blocks = []
         for i in range(levels):
             if i > 0:
-                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+                down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
             block = ResBlock(c_levels[i], c_levels[i] * 4)
             down_blocks.append(block)
         down_blocks.append(nn.Sequential(
-            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
         ))
         self.down_blocks = nn.Sequential(*down_blocks)
@@ -191,7 +195,7 @@ class StageA(nn.Module):

         # Decoder blocks
         up_blocks = [nn.Sequential(
-            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+            ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
         )]
         for i in range(levels):
             for j in range(bottleneck_blocks if i == 0 else 1):
@@ -199,11 +203,11 @@ class StageA(nn.Module):
                 up_blocks.append(block)
             if i < levels - 1:
                 up_blocks.append(
-                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
-                                       padding=1))
+                    ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+                                        padding=1))
         self.up_blocks = nn.Sequential(*up_blocks)
         self.out_block = nn.Sequential(
-            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+            ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
             nn.PixelShuffle(2),
         )
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
         super().__init__()
         d = max(depth - 3, 3)
         layers = [
-            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+            nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
             nn.LeakyReLU(0.2),
         ]
         for i in range(depth - 1):
             c_in = c_hidden // (2 ** max((d - i), 0))
             c_out = c_hidden // (2 ** max((d - 1 - i), 0))
-            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+            layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
             layers.append(nn.InstanceNorm2d(c_out))
             layers.append(nn.LeakyReLU(0.2))
         self.encoder = nn.Sequential(*layers)
-        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+        self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
         self.logits = nn.Sigmoid()

     def forward(self, x, cond=None):

View File

@@ -19,6 +19,9 @@ import torch
 import torchvision
 from torch import nn

+import comfy.ops
+
+ops = comfy.ops.disable_weight_init

 # EfficientNet
 class EfficientNetEncoder(nn.Module):
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
         super().__init__()
         self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
         self.mapper = nn.Sequential(
-            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
         )
         self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):

     def forward(self, x):
         x = x * 0.5 + 0.5
-        x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+        x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
         o = self.mapper(self.backbone(x))
         return o
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
     def __init__(self, c_in=16, c_hidden=512, c_out=3):
         super().__init__()
         self.blocks = nn.Sequential(
-            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
+            ops.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
             nn.GELU(),
             nn.BatchNorm2d(c_hidden),
-            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden),
-            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
+            ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 2),
-            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 2),
-            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
+            ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),
-            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),
-            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
+            ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),
-            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),
-            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+            ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
         )

     def forward(self, x):

View File

@@ -105,7 +105,9 @@ class Modulation(nn.Module):
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)

         return (
             ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
         )

+
+def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
+    if modulation_dims is None:
+        if m_add is not None:
+            return tensor * m_mult + m_add
+        else:
+            return tensor * m_mult
+    else:
+        for d in modulation_dims:
+            tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+            if m_add is not None:
+                tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+        return tensor
+
+
 class DoubleStreamBlock(nn.Module):
     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
         self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
+        x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
         x = self.linear(x)
         return x
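
The modulation_dims entries consumed by apply_mod above are (token_start, token_end, modulation_index) triples: each token slice is scaled by its own row of the modulation tensor, which is how the HunyuanVideo diff below gives the guiding frame's tokens a separate modulation. A small illustration with made-up shapes:

import torch

tokens = torch.ones(1, 6, 4)                 # batch 1, 6 tokens, dim 4
m_mult = torch.tensor([[[2.0], [3.0]]])      # two modulation rows
modulation_dims = [(0, 2, 0), (2, None, 1)]  # first 2 tokens use row 0, the rest row 1
for start, end, idx in modulation_dims:
    tokens[:, start:end] *= m_mult[:, idx]
print(tokens[0, :, 0])  # tensor([2., 2., 3., 3., 3., 3.])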

View File

@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q_shape = q.shape
     k_shape = k.shape

-    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
-    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
-    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+    if pe is not None:
+        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
+        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
+        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:

 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -115,8 +115,11 @@ class Flux(nn.Module):
             vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
         txt = self.txt_in(txt)

-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
+        if img_ids is not None:
+            ids = torch.cat((txt_ids, img_ids), dim=1)
+            pe = self.pe_embedder(ids)
+        else:
+            pe = None

         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):

View File

@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        guiding_frame_index=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+
+        if guiding_frame_index is not None:
+            token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
+            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
+            modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+            modulation_dims_txt = [(0, None, 1)]
+        else:
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            modulation_dims = None
+            modulation_dims_txt = None

         if self.params.guidance_embed:
             if guidance is not None:
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                     return out

-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)

             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                     return out

-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)

             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
         img = img[:, : img_len]
-        img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)

         shape = initial_shape[-3:]
         for i in range(len(shape)):
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img

-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
         return out
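Note: for the guiding-frame path, the tokens of the first latent frame get modulation row 0 and everything else row 1. A quick sketch of the bookkeeping with hypothetical sizes:

latent_shape = (1, 16, 9, 60, 104)   # (B, C, T, H, W), hypothetical
patch_size = (1, 2, 2)

frame_tokens = (latent_shape[-1] // patch_size[-1]) * (latent_shape[-2] // patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]   # text tokens always use the regular row
print(frame_tokens, modulation_dims)   # 1560 [(0, 1560, 0), (1560, None, 1)]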

View File

@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
         logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         exit(-1)

+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out

+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
 optimized_attention = attention_basic

 if model_management.sage_attention_enabled():
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
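Note on layouts: inside attention_flash the tensors are (batch, heads, seq, dim_head), while flash_attn_func expects (batch, seq, heads, dim_head), hence the transpose(1, 2) on the way in and out; on any failure the code falls back to SDPA. A shape sanity-check sketch using SDPA as a stand-in:

import torch
import torch.nn.functional as F

b, h, s, d = 2, 8, 64, 64
q = k = v = torch.randn(b, h, s, d)
out = F.scaled_dot_product_attention(q, k, v)      # the fallback path
out = out.transpose(1, 2).reshape(b, -1, h * d)    # final reshape in attention_flash
print(out.shape)                                   # torch.Size([2, 64, 512])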

View File

@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
         context,
         clip_fea=None,
         freqs=None,
+        transformer_options={},
     ):
         r"""
         Forward pass through the diffusion model
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)

-        # arguments
-        kwargs = dict(
-            e=e0,
-            freqs=freqs,
-            context=context)
-
-        for block in self.blocks:
-            x = block(x, **kwargs)
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context)

         # head
         x = self.head(x, e)
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x

-    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""

View File

@@ -16,6 +16,7 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
+from __future__ import annotations
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,11 +105,11 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.manual_cast_dtype = model_config.manual_cast_dtype
         self.device = device
-        self.current_patcher: 'ModelPatcher' = None
+        self.current_patcher: ModelPatcher = None

         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+                fp8 = model_config.optimizations.get("fp8", False)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.zipper_initialized = False

     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
         ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

     def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+        # handle lowvram zipper initialization, if required
+        if self.model_lowvram and not self.zipper_initialized:
+            if self.current_patcher:
+                self.zipper_initialized = True
+                with self.current_patcher.use_ejected():
+                    loading = self.current_patcher._load_list_lowvram_only()
+        return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
+
+    def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
         xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
@@ -898,20 +910,31 @@ class HunyuanVideo(BaseModel):
         guidance = kwargs.get("guidance", 6.0)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        guiding_frame_index = kwargs.get("guiding_frame_index", None)
+        if guiding_frame_index is not None:
+            out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
+
         return out

+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return latent_image
+
 class HunyuanVideoI2V(HunyuanVideo):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("concat_image", "mask_inverted")

+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
 class HunyuanVideoSkyreelsI2V(HunyuanVideo):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("concat_image",)

+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
 class CosmosVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
@@ -962,11 +985,11 @@ class WAN21(BaseModel):
         self.image_to_video = image_to_video

     def concat_cond(self, **kwargs):
-        if not self.image_to_video:
+        noise = kwargs.get("noise", None)
+        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
             return None

         image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
         device = kwargs["device"]

         if image is None:
@@ -976,6 +999,9 @@ class WAN21(BaseModel):
         image = self.process_latent_in(image)
         image = utils.resize_to_batch_size(image, noise.shape[0])

+        if not self.image_to_video:
+            return image
+
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
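Note: WAN21.concat_cond now decides from tensor shapes rather than the image_to_video flag alone: if the patch embedding takes exactly as many input channels as the noise has, there is nothing to concatenate. A sketch with hypothetical channel counts:

noise_channels = 16
patch_embedding_in_channels = 36   # e.g. 16 noise + 16 latent image + 4 mask
needs_concat = patch_embedding_in_channels != noise_channels
print(needs_concat)  # True -> build concat_latent_image (and mask for I2V)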

View File

@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         model_config.scaled_fp8 = scaled_fp8_weight.dtype
         if model_config.scaled_fp8 == torch.float32:
             model_config.scaled_fp8 = torch.float8_e4m3fn
+        if scaled_fp8_weight.nelement() == 2:
+            model_config.optimizations["fp8"] = False
+        else:
+            model_config.optimizations["fp8"] = True

     return model_config
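Note: a two-element scaled_fp8 marker tensor now disables the fast fp8 optimizations; anything else enables them. Sketch (marker contents are hypothetical; float8 dtypes need a recent PyTorch):

import torch

scaled_fp8_weight = torch.zeros(2, dtype=torch.float8_e4m3fn)
optimizations = {"fp8": scaled_fp8_weight.nelement() != 2}
print(optimizations)  # {'fp8': False}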

View File

@@ -186,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
     else:
         return mem_total

+def mac_version():
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

 try:
     logging.info("pytorch version: {}".format(torch_version))
+    mac_ver = mac_version()
+    if mac_ver is not None:
+        logging.info("Mac Version {}".format(mac_ver))
 except:
     pass
@@ -581,7 +590,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
                 loaded_memory = loaded_model.model_loaded_memory()
                 current_free_mem = get_free_memory(torch_dev) + loaded_memory

-                lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+                lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
                 lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

             if vram_set_state == VRAMState.NO_VRAM:
@@ -921,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention

+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state
@@ -969,12 +981,6 @@ def pytorch_attention_flash_attention():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     return False

-def mac_version():
-    try:
-        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
-    except:
-        return None
-
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
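Note: mac_version() was moved up so the startup banner can log it. It parses platform.mac_ver() into a comparable tuple, e.g. "14.3.1" -> (14, 3, 1), and returns None off macOS:

import platform

def mac_version():
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except:
        return None

print(mac_version())  # e.g. (14, 3, 1) on macOS, None elsewhere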

View File

@@ -17,7 +17,7 @@
 """
 from __future__ import annotations
-from typing import Optional, Callable
+from typing import Optional, Callable, TYPE_CHECKING
 import torch
 import copy
 import inspect
@@ -26,6 +26,7 @@ import uuid
 import collections
 import math

+import comfy.ops
 import comfy.utils
 import comfy.float
 import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
 import comfy.patcher_extension
 from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
 from comfy.comfy_types import UnetWrapperFunction
+if TYPE_CHECKING:
+    from comfy.model_base import BaseModel
+

 def string_to_seed(data):
     crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
-        self.model = model
+        self.model: BaseModel = model
         if not hasattr(self.model, 'device'):
             logging.debug("Model doesn't have a device attribute.")
             self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

+    def _zipper_dict_lowvram_only(self):
+        loading = self._load_list_lowvram_only()
+
+    def _load_list_lowvram_only(self):
+        loading = self._load_list()
+        return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
                 loading.append((comfy.model_management.module_size(m), n, m, params))
         return loading

+    def prepare_teeth(self):
+        ordered_list = self._load_list_lowvram_only()
+        prev_i = None
+        next_i = None
+        # first, create teeth on modules in list
+        for l in ordered_list:
+            m: comfy.ops.CastWeightBiasOp = l[2]
+            m.init_tooth(self.load_device, self.offload_device, l[1])
+        # create teeth linked list
+        for i in range(len(ordered_list)):
+            if i+1 == len(ordered_list):
+                next_i = None
+            else:
+                next_i = i+1
+            m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
+            if prev_i is not None:
+                m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
+            else:
+                m.zipper_tooth.start = True
+            if next_i is not None:
+                m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
+            prev_i = i
+
+    def clean_teeth(self):
+        ordered_list = self._load_list_lowvram_only()
+        for l in ordered_list:
+            m: comfy.ops.CastWeightBiasOp = l[2]
+            m.clean_tooth()
+
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         with self.use_ejected():
             self.unpatch_hooks()
@@ -591,6 +632,8 @@ class ModelPatcher:
             lowvram_counter = 0
             loading = self._load_list()

+            logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
+
             load_completely = []
             loading.sort(reverse=True)
             for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
             if lowvram_counter > 0:
                 logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
                 self.model.model_lowvram = True
+                self.model.zipper_initialized = False
             else:
                 logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                 self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
             self.model.model_loaded_weight_memory = mem_counter
             self.model.current_weight_patches_uuid = self.patches_uuid

+            if self.model.model_lowvram:
+                self.prepare_teeth()
+
             for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
                 callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@ class ModelPatcher:
                 move_weight_functions(m, device_to)
                 wipe_lowvram_weight(m)

+            self.clean_teeth()
             self.model.model_lowvram = False
             self.model.lowvram_patch_counter = 0
@@ -747,6 +795,7 @@ class ModelPatcher:
     def partially_unload(self, device_to, memory_to_free=0):
         with self.use_ejected():
+            hooks_unpatched = False
             memory_freed = 0
             patch_counter = 0
             unload_list = self._load_list()
@@ -770,6 +819,10 @@ class ModelPatcher:
                         move_weight = False
                         break

+                if not hooks_unpatched:
+                    self.unpatch_hooks()
+                    hooks_unpatched = True
+
                 if bk.inplace_update:
                     comfy.utils.copy_to_param(self.model, key, bk.weight)
                 else:
@@ -799,8 +852,10 @@ class ModelPatcher:
                     logging.debug("freed {}".format(n))

             self.model.model_lowvram = True
+            self.model.zipper_initialized = False
             self.model.lowvram_patch_counter += patch_counter
             self.model.model_loaded_weight_memory -= memory_freed
+            self.prepare_teeth()
             return memory_freed

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -1089,7 +1144,6 @@ class ModelPatcher:
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
-            self.unpatch_hooks()
             if hooks is not None:
                 model_sd_keys = list(self.model_state_dict().keys())
                 memory_counter = None
@@ -1100,12 +1154,16 @@ class ModelPatcher:
                 # if have cached weights for hooks, use it
                 cached_weights = self.cached_hook_patches.get(hooks, None)
                 if cached_weights is not None:
+                    model_sd_keys_set = set(model_sd_keys)
                     for key in cached_weights:
                         if key not in model_sd_keys:
                             logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                             continue
                         self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                        model_sd_keys_set.remove(key)
+                    self.unpatch_hooks(model_sd_keys_set)
                 else:
+                    self.unpatch_hooks()
                     relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                     original_weights = None
                     if len(relevant_patches) > 0:
@@ -1116,6 +1174,8 @@ class ModelPatcher:
                             continue
                         self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                          memory_counter=memory_counter)
+            else:
+                self.unpatch_hooks()
             self.current_hooks = hooks

     def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1172,17 +1232,23 @@ class ModelPatcher:
             del out_weight
             del weight

-    def unpatch_hooks(self) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
                 return
             keys = list(self.hook_backup.keys())
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
-            self.hook_backup.clear()
+            if whitelist_keys_set:
+                for k in keys:
+                    if k in whitelist_keys_set:
+                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                        self.hook_backup.pop(k)
+            else:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                self.hook_backup.clear()
             self.current_hooks = None

     def clean_hooks(self):
         self.unpatch_hooks()
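Note on the unpatch_hooks whitelist: with a whitelist only the listed backed-up keys are restored (and removed from the backup); without one everything is restored and the backup cleared. A minimal sketch of that bookkeeping:

hook_backup = {"a.weight": 1, "b.weight": 2}
whitelist = {"a.weight"}
for k in list(hook_backup.keys()):
    if k in whitelist:
        hook_backup.pop(k)   # the real code first copies the backup into the model
print(hook_backup)           # {'b.weight': 2} stays patched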

View File

@@ -16,7 +16,9 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
+from __future__ import annotations
 import torch
+import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
@@ -55,6 +57,79 @@ class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = []
     bias_function = []
+    zipper_init: dict = None
+    zipper_tooth: ZipperTooth = None
+    _zipper_tooth: ZipperTooth = None
+
+    def init_tooth(self, load_device, offload_device, key: str=None):
+        if self.zipper_tooth:
+            self.clean_tooth()
+        self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
+
+    def clean_tooth(self):
+        if self.zipper_tooth:
+            del self.zipper_tooth
+        self.zipper_tooth = None
+
+    def connect_teeth(self):
+        if self.zipper_init is not None:
+            self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+            self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+    # def zipper_connect(self):
+    #     if self.zipper_dict is not None:
+    #         self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+    #         self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+class ZipperTooth:
+    def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
+        self.op = op
+        self.key: str = key
+        self.weight_preloaded: torch.Tensor = None
+        self.bias_preloaded: torch.Tensor = None
+        self.load_device = load_device
+        self.offload_device = offload_device
+        self.start = False
+        self.prev_tooth: ZipperTooth = None
+        self.next_tooth: ZipperTooth = None
+
+    def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+        try:
+            if self.start:
+                return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
+            return self.weight_preloaded, self.bias_preloaded
+        finally:
+            # if self.prev_tooth:
+            #     self.prev_tooth.offload_previous(0)
+            self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
+
+    def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+        # TODO: queue load of tensors
+        if input is not None:
+            if dtype is None:
+                dtype = input.dtype
+            if bias_dtype is None:
+                bias_dtype = dtype
+            if device is None:
+                device = input.device
+        non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
+        if self.op.bias is not None:
+            self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
+        self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
+        if self.next_tooth and teeth_count:
+            self.next_tooth.preload_next(teeth_count-1)
+
+    def offload_previous(self, teeth_count=1):
+        # TODO: queue offload of tensors
+        self.weight_preloaded = None
+        self.bias_preloaded = None
+        if self.prev_tooth and teeth_count:
+            self.prev_tooth.offload_previous(teeth_count-1)
+
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -62,7 +137,11 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            #if self.zipper_init:
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -76,7 +155,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -90,7 +172,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -104,7 +189,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -118,7 +206,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
@@ -133,7 +224,10 @@ class disable_weight_init:
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                if self.zipper_tooth:
+                    weight, bias = self.zipper_tooth.get_bias_weight(input)
+                else:
+                    weight, bias = cast_bias_weight(self, input)
             else:
                 weight = None
                 bias = None
@@ -155,7 +249,10 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
@@ -176,7 +273,10 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
@@ -196,7 +296,10 @@ class disable_weight_init:
                 output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
+            else:
+                weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
@@ -308,6 +411,7 @@ class fp8_ops(manual_cast):
         return torch.nn.functional.linear(input, weight, bias)

 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
         class Linear(manual_cast.Linear):
             def __init__(self, *args, **kwargs):
@@ -358,7 +462,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

     if (
         fp8_compute and
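Note: the "zipper" idea is a linked chain of teeth (one per lowvram module); while a tooth's forward runs, the next tooth's weights are staged onto the device. A conceptual, dependency-free sketch (not the ZipperTooth class above):

class ToothSketch:
    def __init__(self, name):
        self.name, self.next_tooth, self.preloaded = name, None, False

    def forward(self):
        if self.next_tooth:
            self.next_tooth.preloaded = True   # stand-in for the async cast_to
        return self.name

teeth = [ToothSketch(f"t{i}") for i in range(3)]
for a, b in zip(teeth, teeth[1:]):
    a.next_tooth = b
teeth[0].forward()
print(teeth[1].preloaded)  # True: t1's weights were staged while t0 ran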

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
     from comfy.controlnet import ControlBase
+    from comfy.ops import CastWeightBiasOp
 import torch
 from functools import partial
 import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
 import comfy.hooks
 import scipy.stats
 import numpy
+import comfy.ops

 def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o

 #The main sampling function shared by all the samplers
 #Returns denoised
-def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond

+    do_cleanup = False
+    if "weight_zipper" not in model_options:
+        do_cleanup = True
+        #zipper_dict = {}
+        model_options["weight_zipper"] = True
+        loaded_modules = model.current_patcher._load_list_lowvram_only()
+        low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
+        sum_m = sum([x[0] for x in low_m])
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(m, "comfy_cast_weights"):
+                m.zipper_tooth = comfy.ops.ZipperTooth
+                #m.zipper_dict = zipper_dict
+                m.zipper_key = l[1]
+
     conds = [cond, uncond_]
     out = calc_cond_batch(model, conds, x, timestep, model_options)

+    if do_cleanup:
+        zzz = 20
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(l[2], "comfy_cast_weights"):
+                #m.zipper_dict = None
+                m.zipper_key = None
+
     for fn in model_options.get("sampler_pre_cfg_function", []):
         args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
                 "input": x, "sigma": timestep, "model": model, "model_options": model_options}
@@ -710,7 +735,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation"]
+                  "gradient_estimation", "er_sde"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
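Note: adding "er_sde" to KSAMPLER_NAMES is what exposes the new sampler to the UI; selection is a plain membership lookup. Sketch:

KSAMPLER_NAMES_ABRIDGED = ["euler", "gradient_estimation", "er_sde"]
assert "er_sde" in KSAMPLER_NAMES_ABRIDGED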

View File

@@ -440,6 +440,10 @@ class VAE:
             self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
             logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

+    def throw_exception_if_invalid(self):
+        if self.first_stage_model is None:
+            raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
+
     def vae_encode_crop_pixels(self, pixels):
         downscale_ratio = self.spacial_compression_encode()
@@ -495,6 +499,7 @@ class VAE:
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

     def decode(self, samples_in):
+        self.throw_exception_if_invalid()
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -525,6 +530,7 @@ class VAE:
         return pixel_samples

     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
         model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
@@ -553,6 +559,7 @@ class VAE:
         return output.movedim(1, -1)

     def encode(self, pixel_samples):
+        self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -585,6 +592,7 @@ class VAE:
         return samples

     def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
-        return None
+        logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
+        diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
+        if diffusion_model is None:
+            return None
+        return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
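Note: the fallback returns a placeholder VAE built from an empty state dict, so first_stage_model stays None and any decode/encode call raises the explicit error above instead of failing cryptically. A sketch of that contract:

class VAESketch:
    def __init__(self, sd):
        self.first_stage_model = sd or None

    def throw_exception_if_invalid(self):
        if self.first_stage_model is None:
            raise RuntimeError("VAE is invalid: None")

try:
    VAESketch(sd={}).throw_exception_if_invalid()
except RuntimeError as e:
    print(e)  # VAE is invalid: None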

View File

@@ -228,6 +228,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             if pad_extra > 0:
                 padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
                 tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
+                attention_mask = attention_mask + [0] * pad_extra

             embeds_out.append(tokens_embed)
             attention_masks.append(attention_mask)
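Note: the attention mask must grow together with the padded embeddings, with the extra pad positions marked masked-out. Sketch:

attention_mask = [1, 1, 1]
pad_extra = 2
attention_mask = attention_mask + [0] * pad_extra
print(attention_mask)  # [1, 1, 1, 0, 0]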

View File

@@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
     memory_usage_factor = 1.0

-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

View File

@@ -42,7 +42,7 @@ class HunyuanVideoTokenizer:
         self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
         self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)

-    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
@@ -56,7 +56,7 @@ class HunyuanVideoTokenizer:
         for i in range(len(r)):
             if r[i][0] == 128257:
                 if image_embeds is not None and embed_count < image_embeds.shape[0]:
-                    r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
+                    r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
                     embed_count += 1
         out["llama"] = llama_text_tokens
         return out
@@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module):
         llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)

         template_end = 0
-        image_start = None
-        image_end = None
+        extra_template_end = 0
         extra_sizes = 0
         user_end = 9999999999999
+        images = []

         tok_pairs = token_weight_pairs_llama[0]
         for i, v in enumerate(tok_pairs):
@@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module):
                 else:
                     if elem.get("original_type") == "image":
                         elem_size = elem.get("data").shape[0]
-                        if image_start is None:
+                        if template_end > 0:
+                            if user_end == -1:
+                                extra_template_end += elem_size - 1
+                        else:
                             image_start = i + extra_sizes
                             image_end = i + elem_size + extra_sizes
-                        extra_sizes += elem_size - 1
+                            images.append((image_start, image_end, elem.get("image_interleave", 1)))
+                            extra_sizes += elem_size - 1

         if llama_out.shape[1] > (template_end + 2):
             if tok_pairs[template_end + 1][0] == 271:
                 template_end += 2
-        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
-        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
+        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
+        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
         if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
             llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements

-        if image_start is not None:
-            image_output = llama_out[:, image_start: image_end]
-            llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
+        if len(images) > 0:
+            out = []
+            for i in images:
+                out.append(llama_out[:, i[0]: i[1]: i[2]])
+            llama_output = torch.cat(out + [llama_output], dim=1)

         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return llama_output, l_pooled, llama_extra_out
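Note: image_interleave subsamples the image embedding tokens with a stride before they are prepended to the text tokens, so a higher value keeps fewer image tokens. Sketch with hypothetical sizes:

import torch

llama_out = torch.randn(1, 300, 4096)
images = [(10, 74, 2)]                      # (start, end, interleave)
picked = [llama_out[:, s:e:step] for s, e, step in images]
print(picked[0].shape[1])                   # 32 of the 64 image tokens kept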

View File

@@ -57,17 +57,17 @@ class TextEncodeHunyuanVideo_ImageToVideo:
             "clip": ("CLIP", ),
             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
             "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
             }}

     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"

     CATEGORY = "advanced/conditioning"

-    def encode(self, clip, clip_vision_output, prompt):
-        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
+    def encode(self, clip, clip_vision_output, prompt, image_interleave):
+        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
         return (clip.encode_from_tokens_scheduled(tokens), )

 class HunyuanImageToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -77,6 +77,7 @@ class HunyuanImageToVideo:
                     "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                     "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                     "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                    "guidance_type": (["v1 (concat)", "v2 (replace)"], )
                     },
                 "optional": {"start_image": ("IMAGE", ),
                 }}
@@ -87,8 +88,10 @@ class HunyuanImageToVideo:
     CATEGORY = "conditioning/video_models"

-    def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
+    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        out_latent = {}

         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -96,13 +99,20 @@ class HunyuanImageToVideo:
             mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
             mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

-            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            else:
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+
+            positive = node_helpers.conditioning_set_values(positive, cond)

-        out_latent = {}
         out_latent["samples"] = latent
         return (positive, out_latent)

 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,

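How the two guidance modes in the new `guidance_type` branch diverge is easiest to see with dummy tensors standing in for the VAE-encoded start frame; this is an illustrative sketch of the branch, not the node itself (shapes are made up):

    import torch

    latent = torch.zeros([1, 16, 14, 60, 80])              # ((53 - 1) // 4) + 1 = 14 latent frames
    concat_latent_image = torch.randn([1, 16, 1, 60, 80])  # dummy encoded start frame
    mask = torch.ones((1, 1, 14, 60, 80))
    mask[:, :, :1] = 0.0  # frames that are given should not be noised

    guidance_type = "v2 (replace)"
    out_latent = {}
    if guidance_type == "v1 (concat)":
        # v1 ships the encoded frame and mask through the conditioning only.
        cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
    else:
        # v2 writes the frame straight into the latent, marks frame 0 as the
        # guiding frame, and hands the sampler a noise mask instead.
        cond = {"guiding_frame_index": 0}
        latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
        out_latent["noise_mask"] = mask
    out_latent["samples"] = latent
    print(sorted(cond))  # ['guiding_frame_index']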
@@ -19,8 +19,6 @@ class Load3D():
             "image": ("LOAD_3D", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -55,8 +53,6 @@ class Load3DAnimation():
             "image": ("LOAD_3D_ANIMATION", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -82,8 +78,6 @@ class Preview3D():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     OUTPUT_NODE = True
@@ -102,8 +96,6 @@ class Preview3DAnimation():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     OUTPUT_NODE = True

@@ -99,12 +99,13 @@ class LTXVAddGuide:
                 "negative": ("CONDITIONING", ),
                 "vae": ("VAE",),
                 "latent": ("LATENT",),
-                "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
-                    "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
+                "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
+                    "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                 "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
-                              "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
-                              "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
-                              "Negative values are counted from the end of the video."}),
+                              "tooltip": "Frame index to start the conditioning at. For single-frame images or "
+                              "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
+                              "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
+                              "the nearest multiple of 8. Negative values are counted from the end of the video."}),
                 "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
@@ -127,12 +128,13 @@ class LTXVAddGuide:
         t = vae.encode(encode_pixels)
         return encode_pixels, t

-    def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
+    def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
         time_scale_factor, _, _ = scale_factors
         _, num_keyframes = get_keyframe_idxs(cond)
         latent_count = latent_length - num_keyframes
-        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
-        frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
+        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
+        if guide_length > 1:
+            frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8

         latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
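Worked numbers make the new rounding rule concrete; assuming the usual LTXV time scale factor of 8 and a 14-frame latent (both values illustrative):

    time_scale_factor = 8  # assumed temporal compression of the LTXV VAE
    latent_count = 14      # latent frames left after subtracting keyframes

    frame_idx = -1         # negative values count from the end of the video
    frame_idx = max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
    print(frame_idx)       # 104, the last pixel frame

    guide_length = 9       # multi-frame guides must start on a multiple of 8
    if guide_length > 1:
        frame_idx = frame_idx // time_scale_factor * time_scale_factor
    print(frame_idx)       # 104 (already aligned); 107 would round down to 104

    latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
    print(latent_idx)      # 13

A single-frame guide (guide_length == 1) skips the rounding entirely, which is the point of threading guide_length through get_latent_index.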
@@ -191,7 +193,7 @@ class LTXVAddGuide:
         _, _, latent_length, latent_height, latent_width = latent_image.shape
         image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

-        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
+        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

         num_prefix_frames = min(self._num_prefix_frames, t.shape[2])

@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.23"
+__version__ = "0.3.26"

@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
                 continue
             else:
                 try:
+                    # Unwraps values wrapped in __value__ key. This is used to pass
+                    # list widget value to execution, as by default list value is
+                    # reserved to represent the connection between nodes.
+                    if isinstance(val, dict) and "__value__" in val:
+                        val = val["__value__"]
+                        inputs[x] = val
+
                     if type_input == "INT":
                         val = int(val)
                         inputs[x] = val
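The unwrap exists because a plain list value would be indistinguishable from a node-to-node connection, which is also encoded as a list. A minimal sketch with a hypothetical widget payload:

    # Hypothetical payload from a multi-select widget.
    val = {"__value__": ["option_a", "option_b"]}

    if isinstance(val, dict) and "__value__" in val:
        val = val["__value__"]

    print(val)  # ['option_a', 'option_b']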

@@ -139,6 +139,7 @@ from server import BinaryEventTypes
 import nodes
 import comfy.model_management
 import comfyui_version
+import app.logger


 def cuda_malloc_warning():
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
 if __name__ == "__main__":
     # Running directly, just start ComfyUI.
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))

     event_loop, _, start_all_func = start_comfyui()
     try:
-        event_loop.run_until_complete(start_all_func())
+        x = start_all_func()
+        app.logger.print_startup_warnings()
+        event_loop.run_until_complete(x)
     except KeyboardInterrupt:
         logging.info("\nStopped server")
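The reorder works because calling an async function only creates a coroutine object; nothing executes until the loop drives it, so the warnings land in the log before the server blocks the process. A stripped-down, runnable illustration (names here are placeholders, not ComfyUI's API):

    import asyncio

    async def start_all():
        print("server running")

    loop = asyncio.new_event_loop()
    x = start_all()  # coroutine created, start_all has not run yet
    print("startup warnings print here, ahead of the blocking loop")
    loop.run_until_complete(x)
    loop.close()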

@@ -489,7 +489,7 @@ class SaveLatent:
         file = os.path.join(full_output_folder, file)

         output = {}
-        output["latent_tensor"] = samples["samples"]
+        output["latent_tensor"] = samples["samples"].contiguous()
         output["latent_format_version_0"] = torch.tensor([])

         comfy.utils.save_torch_file(output, file, metadata=metadata)
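The `.contiguous()` call matters because latents routinely arrive as permuted views, and serializers that write raw buffers (safetensors among them) reject non-contiguous tensors. A quick runnable check:

    import torch

    samples = torch.randn(1, 4, 8, 8).movedim(1, -1)  # a view, not a copy
    print(samples.is_contiguous())                    # False
    print(samples.contiguous().is_contiguous())       # True, safe to serialize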
@@ -770,6 +770,7 @@ class VAELoader:
         vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
         sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
+        vae.throw_exception_if_invalid()
         return (vae,)

 class ControlNetLoader:
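The new call is a fail-fast guard: a file that matches no known VAE layout now errors with a clear message at load time instead of failing later during decode. A hedged sketch of the pattern with a stand-in class (the real check lives in comfy.sd.VAE and is not reproduced here):

    class FakeVAE:
        def __init__(self, sd):
            self.sd = sd

        def throw_exception_if_invalid(self):
            # Stand-in check; the real method inspects the loaded state dict.
            if not self.sd:
                raise RuntimeError("VAE: state dict did not match any known VAE layout")

    try:
        FakeVAE(sd={}).throw_exception_if_invalid()
    except RuntimeError as e:
        print(e)  # surfaces at load time, not mid-sampling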
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
     DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
     EXPERIMENTAL = True
-    FUNCTION = "load_image_output"
-
-    def load_image_output(self, image):
-        return self.load_image(f"{image} [output]")
-
-    @classmethod
-    def VALIDATE_INPUTS(s, image):
-        return True
+    FUNCTION = "load_image"


 class ImageScale:

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.23"
+version = "0.3.26"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"

@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.10.17
+comfyui-frontend-package==1.12.14
 torch
 torchsde
 torchvision

@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
 def test_init_frontend_default():
     version_string = DEFAULT_VERSION_STRING
     frontend_path = FrontendManager.init_frontend(version_string)
-    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+    assert frontend_path == FrontendManager.default_frontend_path()


 def test_init_frontend_invalid_version():
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)


 @pytest.fixture
 def mock_os_functions():
-    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
-         patch('app.frontend_management.os.listdir') as mock_listdir, \
-         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+    with (
+        patch("app.frontend_management.os.makedirs") as mock_makedirs,
+        patch("app.frontend_management.os.listdir") as mock_listdir,
+        patch("app.frontend_management.os.rmdir") as mock_rmdir,
+    ):
         mock_listdir.return_value = []  # Simulate empty directory
         yield mock_makedirs, mock_listdir, mock_rmdir


 @pytest.fixture
 def mock_download():
-    with patch('app.frontend_management.download_release_asset_zip') as mock:
+    with patch("app.frontend_management.download_release_asset_zip") as mock:
         mock.side_effect = Exception("Download failed")  # Simulate download failure
         yield mock


 def test_finally_block(mock_os_functions, mock_download, mock_provider):
     # Arrange
     mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
-    version_string = 'test-owner/test-repo@1.0.0'
+    version_string = "test-owner/test-repo@1.0.0"

     # Act & Assert
     with pytest.raises(Exception):
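Besides normalizing quote style, the fixture swaps backslash continuations for a parenthesized multi-item `with` block; that form is officially part of the grammar from Python 3.10, though CPython 3.9's parser already accepts it (relevant while the project declares requires-python >= 3.9). A tiny runnable equivalent:

    from contextlib import nullcontext

    # Parenthesized with-block: no backslashes, trailing comma allowed.
    with (
        nullcontext(1) as a,
        nullcontext(2) as b,
    ):
        print(a + b)  # 3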
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
     version_string = "invalid"
     with pytest.raises(argparse.ArgumentTypeError):
         FrontendManager.parse_version_string(version_string)
+
+
+def test_init_frontend_default_with_mocks():
+    # Arrange
+    version_string = DEFAULT_VERSION_STRING
+
+    # Act
+    with (
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/mocked/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/mocked/path"
+    mock_check.assert_called_once()
+
+
+def test_init_frontend_fallback_on_error():
+    # Arrange
+    version_string = "test-owner/test-repo@1.0.0"
+
+    # Act
+    with (
+        patch.object(
+            FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
+        ),
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/default/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/default/path"
+    mock_check.assert_called_once()