Compare commits
59 Commits
v0.3.23
...
weight-zip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8037ab667 | ||
|
|
3b19fc76e3 | ||
|
|
50614f1b79 | ||
|
|
6dc7b0bfe3 | ||
|
|
e8e990d6b8 | ||
|
|
2e24a15905 | ||
|
|
fd5297131f | ||
|
|
55a1b09ddc | ||
|
|
3c3988df45 | ||
|
|
7ebd8087ff | ||
|
|
c624c29d66 | ||
|
|
a2448fc527 | ||
|
|
6a0daa79b6 | ||
|
|
9c98c6358b | ||
|
|
7aceb9f91c | ||
|
|
35504e2f93 | ||
|
|
299436cfed | ||
|
|
52e566d2bc | ||
|
|
9b6cd9b874 | ||
|
|
3fc688aebd | ||
|
|
f4411250f3 | ||
|
|
d2a0fb6bb0 | ||
|
|
01015bff16 | ||
|
|
2330754b0e | ||
|
|
bc219a6487 | ||
|
|
94689766ad | ||
|
|
cfbe4b49ca | ||
|
|
ca8efab79f | ||
|
|
65ea778a5e | ||
|
|
db9f2a34fc | ||
|
|
7946049794 | ||
|
|
6f6349b6a7 | ||
|
|
1f138dd382 | ||
|
|
b779349b55 | ||
|
|
35e2dcf5d7 | ||
|
|
67c7184b74 | ||
|
|
6f8e766509 | ||
|
|
e1da98a14a | ||
|
|
a73410aafa | ||
|
|
9aac21f894 | ||
|
|
528d1b3563 | ||
|
|
2bc4b5968f | ||
|
|
7395b0c0d1 | ||
|
|
0952569493 | ||
|
|
29832b3b61 | ||
|
|
be4e760648 | ||
|
|
c3d9cc4592 | ||
|
|
84cc9cb528 | ||
|
|
ebbb920163 | ||
|
|
d60fe0af4a | ||
|
|
5dbd250965 | ||
|
|
4ab1875283 | ||
|
|
11b1f27cb1 | ||
|
|
70e15fd743 | ||
|
|
e1474150de | ||
|
|
e62d72e8ca | ||
|
|
1650cda030 | ||
|
|
a13125840c | ||
|
|
dfa36e6855 |
@@ -0,0 +1,2 @@
|
|||||||
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||||
|
pause
|
||||||
@@ -7,7 +7,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "126"
|
default: "128"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@@ -19,7 +19,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "1"
|
default: "2"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
@@ -34,7 +34,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 30
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@@ -74,7 +74,7 @@ jobs:
|
|||||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||||
|
|
||||||
cd ComfyUI_windows_portable_nightly_pytorch
|
cd ComfyUI_windows_portable_nightly_pytorch
|
||||||
|
|||||||
@@ -19,5 +19,6 @@
|
|||||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||||
|
|
||||||
# Extra nodes
|
# Node developers
|
||||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||||
|
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||||
|
|||||||
@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:
|
|||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||||
|
|
||||||
#### Troubleshooting
|
#### Troubleshooting
|
||||||
|
|
||||||
|
|||||||
@@ -11,20 +11,43 @@ from dataclasses import dataclass
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypedDict, Optional
|
from typing import TypedDict, Optional
|
||||||
|
from importlib.metadata import version
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
|
import app.logger
|
||||||
|
|
||||||
|
# The path to the requirements.txt file
|
||||||
|
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||||
|
|
||||||
|
def frontend_install_warning_message():
|
||||||
|
"""The warning message to display when the frontend version is not up to date."""
|
||||||
|
|
||||||
|
extra = ""
|
||||||
|
if sys.flags.no_user_site:
|
||||||
|
extra = "-s "
|
||||||
|
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
|
||||||
|
|
||||||
|
|
||||||
try:
|
def check_frontend_version():
|
||||||
import comfyui_frontend_package
|
"""Check if the frontend version is up to date."""
|
||||||
except ImportError:
|
|
||||||
# TODO: Remove the check after roll out of 0.3.16
|
def parse_version(version: str) -> tuple[int, int, int]:
|
||||||
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
|
return tuple(map(int, version.split(".")))
|
||||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
|
|
||||||
exit(-1)
|
try:
|
||||||
|
frontend_version_str = version("comfyui-frontend-package")
|
||||||
|
frontend_version = parse_version(frontend_version_str)
|
||||||
|
with open(req_path, "r", encoding="utf-8") as f:
|
||||||
|
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||||
|
if frontend_version < required_frontend:
|
||||||
|
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
|
||||||
|
else:
|
||||||
|
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to check frontend version: {e}")
|
||||||
|
|
||||||
|
|
||||||
REQUEST_TIMEOUT = 10 # seconds
|
REQUEST_TIMEOUT = 10 # seconds
|
||||||
@@ -121,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class FrontendManager:
|
class FrontendManager:
|
||||||
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
|
|
||||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_frontend_path(cls) -> str:
|
||||||
|
try:
|
||||||
|
import comfyui_frontend_package
|
||||||
|
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||||
|
except ImportError:
|
||||||
|
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
@@ -160,7 +191,8 @@ class FrontendManager:
|
|||||||
main error source might be request timeout or invalid URL.
|
main error source might be request timeout or invalid URL.
|
||||||
"""
|
"""
|
||||||
if version_string == DEFAULT_VERSION_STRING:
|
if version_string == DEFAULT_VERSION_STRING:
|
||||||
return cls.DEFAULT_FRONTEND_PATH
|
check_frontend_version()
|
||||||
|
return cls.default_frontend_path()
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
|
|
||||||
@@ -213,4 +245,5 @@ class FrontendManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Failed to initialize frontend: %s", e)
|
logging.error("Failed to initialize frontend: %s", e)
|
||||||
logging.info("Falling back to the default frontend.")
|
logging.info("Falling back to the default frontend.")
|
||||||
return cls.DEFAULT_FRONTEND_PATH
|
check_frontend_version()
|
||||||
|
return cls.default_frontend_path()
|
||||||
|
|||||||
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
|||||||
logger.addHandler(stdout_handler)
|
logger.addHandler(stdout_handler)
|
||||||
|
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
|
STARTUP_WARNINGS = []
|
||||||
|
|
||||||
|
|
||||||
|
def log_startup_warning(msg):
|
||||||
|
logging.warning(msg)
|
||||||
|
STARTUP_WARNINGS.append(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def print_startup_warnings():
|
||||||
|
for s in STARTUP_WARNINGS:
|
||||||
|
logging.warning(s)
|
||||||
|
STARTUP_WARNINGS.clear()
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
|
|||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||||
|
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import comfy.model_patcher
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.clip_model
|
import comfy.clip_model
|
||||||
|
import comfy.image_encoders.dino2
|
||||||
|
|
||||||
class Output:
|
class Output:
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
|||||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||||
|
|
||||||
|
IMAGE_ENCODERS = {
|
||||||
|
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
|
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
|
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||||
|
}
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
def __init__(self, json_config):
|
def __init__(self, json_config):
|
||||||
with open(json_config) as f:
|
with open(json_config) as f:
|
||||||
@@ -42,10 +49,11 @@ class ClipVisionModel():
|
|||||||
self.image_size = config.get("image_size", 224)
|
self.image_size = config.get("image_size", 224)
|
||||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||||
|
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
self.load_device = comfy.model_management.text_encoder_device()
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||||
else:
|
else:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
|
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||||
|
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
from typing_extensions import NotRequired
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ class IO(StrEnum):
|
|||||||
BOOLEAN = "BOOLEAN"
|
BOOLEAN = "BOOLEAN"
|
||||||
INT = "INT"
|
INT = "INT"
|
||||||
FLOAT = "FLOAT"
|
FLOAT = "FLOAT"
|
||||||
|
COMBO = "COMBO"
|
||||||
CONDITIONING = "CONDITIONING"
|
CONDITIONING = "CONDITIONING"
|
||||||
SAMPLER = "SAMPLER"
|
SAMPLER = "SAMPLER"
|
||||||
SIGMAS = "SIGMAS"
|
SIGMAS = "SIGMAS"
|
||||||
@@ -66,6 +68,7 @@ class IO(StrEnum):
|
|||||||
b = frozenset(value.split(","))
|
b = frozenset(value.split(","))
|
||||||
return not (b.issubset(a) or a.issubset(b))
|
return not (b.issubset(a) or a.issubset(b))
|
||||||
|
|
||||||
|
|
||||||
class RemoteInputOptions(TypedDict):
|
class RemoteInputOptions(TypedDict):
|
||||||
route: str
|
route: str
|
||||||
"""The route to the remote source."""
|
"""The route to the remote source."""
|
||||||
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
|
|||||||
refresh: int
|
refresh: int
|
||||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSelectOptions(TypedDict):
|
||||||
|
placeholder: NotRequired[str]
|
||||||
|
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||||
|
chip: NotRequired[bool]
|
||||||
|
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||||
|
|
||||||
|
|
||||||
class InputTypeOptions(TypedDict):
|
class InputTypeOptions(TypedDict):
|
||||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||||
|
|
||||||
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
|
|||||||
# default: bool
|
# default: bool
|
||||||
label_on: str
|
label_on: str
|
||||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||||
label_on: str
|
label_off: str
|
||||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||||
# class InputTypeString(InputTypeOptions):
|
# class InputTypeString(InputTypeOptions):
|
||||||
# default: str
|
# default: str
|
||||||
@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
|
|||||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||||
"""
|
"""
|
||||||
remote: RemoteInputOptions
|
remote: RemoteInputOptions
|
||||||
"""Specifies the configuration for a remote input."""
|
"""Specifies the configuration for a remote input.
|
||||||
|
Available after ComfyUI frontend v1.9.7
|
||||||
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||||
control_after_generate: bool
|
control_after_generate: bool
|
||||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||||
|
options: NotRequired[list[str | int | float]]
|
||||||
|
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||||
|
Prefer:
|
||||||
|
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
||||||
|
Over:
|
||||||
|
[["Option 1", "Option 2", "Option 3"]]
|
||||||
|
"""
|
||||||
|
multi_select: NotRequired[MultiSelectOptions]
|
||||||
|
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||||
|
Available after ComfyUI frontend v1.13.4
|
||||||
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||||
|
|
||||||
|
|
||||||
class HiddenInputTypeDict(TypedDict):
|
class HiddenInputTypeDict(TypedDict):
|
||||||
|
|||||||
141
comfy/image_encoders/dino2.py
Normal file
141
comfy/image_encoders/dino2.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import torch
|
||||||
|
from comfy.text_encoders.bert import BertAttention
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2AttentionOutput(torch.nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.dense(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2AttentionBlock(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||||
|
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||||
|
|
||||||
|
def forward(self, x, mask, optimized_attention):
|
||||||
|
return self.output(self.attention(x, mask, optimized_attention))
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale(torch.nn.Module):
|
||||||
|
def __init__(self, dim, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFFN(torch.nn.Module):
|
||||||
|
def __init__(self, dim, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
in_features = out_features = dim
|
||||||
|
hidden_features = int(dim * 4)
|
||||||
|
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||||
|
|
||||||
|
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||||
|
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.weights_in(x)
|
||||||
|
x1, x2 = x.chunk(2, dim=-1)
|
||||||
|
x = torch.nn.functional.silu(x1) * x2
|
||||||
|
return self.weights_out(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2Block(torch.nn.Module):
|
||||||
|
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||||
|
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||||
|
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||||
|
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||||
|
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||||
|
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, optimized_attention):
|
||||||
|
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||||
|
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2Encoder(torch.nn.Module):
|
||||||
|
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||||
|
|
||||||
|
def forward(self, x, intermediate_output=None):
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||||
|
|
||||||
|
if intermediate_output is not None:
|
||||||
|
if intermediate_output < 0:
|
||||||
|
intermediate_output = len(self.layer) + intermediate_output
|
||||||
|
|
||||||
|
intermediate = None
|
||||||
|
for i, l in enumerate(self.layer):
|
||||||
|
x = l(x, optimized_attention)
|
||||||
|
if i == intermediate_output:
|
||||||
|
intermediate = x.clone()
|
||||||
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||||
|
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.projection = operations.Conv2d(
|
||||||
|
in_channels=num_channels,
|
||||||
|
out_channels=dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, pixel_values):
|
||||||
|
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class Dino2Embeddings(torch.nn.Module):
|
||||||
|
def __init__(self, dim, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
patch_size = 14
|
||||||
|
image_size = 518
|
||||||
|
|
||||||
|
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||||
|
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||||
|
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
def forward(self, pixel_values):
|
||||||
|
x = self.patch_embeddings(pixel_values)
|
||||||
|
# TODO: mask_token?
|
||||||
|
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||||
|
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Dinov2Model(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
num_layers = config_dict["num_hidden_layers"]
|
||||||
|
dim = config_dict["hidden_size"]
|
||||||
|
heads = config_dict["num_attention_heads"]
|
||||||
|
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||||
|
|
||||||
|
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||||
|
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||||
|
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||||
|
x = self.embeddings(pixel_values)
|
||||||
|
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||||
|
x = self.layernorm(x)
|
||||||
|
pooled_output = x[:, 0, :]
|
||||||
|
return x, i, pooled_output, None
|
||||||
21
comfy/image_encoders/dino2_giant.json
Normal file
21
comfy/image_encoders/dino2_giant.json
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"attention_probs_dropout_prob": 0.0,
|
||||||
|
"drop_path_rate": 0.0,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_dropout_prob": 0.0,
|
||||||
|
"hidden_size": 1536,
|
||||||
|
"image_size": 518,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"layer_norm_eps": 1e-06,
|
||||||
|
"layerscale_value": 1.0,
|
||||||
|
"mlp_ratio": 4,
|
||||||
|
"model_type": "dinov2",
|
||||||
|
"num_attention_heads": 24,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 40,
|
||||||
|
"patch_size": 14,
|
||||||
|
"qkv_bias": true,
|
||||||
|
"use_swiglu_ffn": true,
|
||||||
|
"image_mean": [0.485, 0.456, 0.406],
|
||||||
|
"image_std": [0.229, 0.224, 0.225]
|
||||||
|
}
|
||||||
@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
|||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
sigma_fn = lambda t: t.neg().exp()
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
t_fn = lambda sigma: sigma.log().neg()
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
if solver_type not in {'heun', 'midpoint'}:
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
old_denoised = None
|
old_denoised = None
|
||||||
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
denoised_1, denoised_2 = None, None
|
denoised_1, denoised_2 = None, None
|
||||||
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Run sample_dpmpp_3m_sde with a Brownian-tree noise sampler built with cpu=False.

    Returns ``x`` untouched when ``sigmas`` has at most one entry. When no
    ``noise_sampler`` is given, one is built from ``x``, the smallest strictly
    positive sigma, the largest sigma and the ``"seed"`` entry of
    ``extra_args`` (``None`` if absent). Everything else is forwarded as-is
    to ``sample_dpmpp_3m_sde``.
    """
    if len(sigmas) <= 1:
        return x

    # Normalise first so the seed lookup below cannot hit None.
    if extra_args is None:
        extra_args = {}

    # Computed unconditionally, exactly like the one-line form it replaces.
    sigma_min = sigmas[sigmas > 0].min()
    sigma_max = sigmas.max()
    if noise_sampler is None:
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)

    return sample_dpmpp_3m_sde(
        model, x, sigmas,
        extra_args=extra_args, callback=callback, disable=disable,
        eta=eta, s_noise=s_noise, noise_sampler=noise_sampler,
    )
|
||||||
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
|||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """Run sample_dpmpp_2m_sde with a Brownian-tree noise sampler built with cpu=False.

    Returns ``x`` untouched when ``sigmas`` has at most one entry. When no
    ``noise_sampler`` is given, one is built from ``x``, the smallest strictly
    positive sigma, the largest sigma and the ``"seed"`` entry of
    ``extra_args`` (``None`` if absent). ``solver_type`` and the remaining
    arguments are forwarded unchanged to ``sample_dpmpp_2m_sde``.
    """
    if len(sigmas) <= 1:
        return x

    # Normalise first so the seed lookup below cannot hit None.
    if extra_args is None:
        extra_args = {}

    # Computed unconditionally, exactly like the one-line form it replaces.
    sigma_min = sigmas[sigmas > 0].min()
    sigma_max = sigmas.max()
    if noise_sampler is None:
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)

    return sample_dpmpp_2m_sde(
        model, x, sigmas,
        extra_args=extra_args, callback=callback, disable=disable,
        eta=eta, s_noise=s_noise, noise_sampler=noise_sampler,
        solver_type=solver_type,
    )
|
||||||
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
|||||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """Run sample_dpmpp_sde with a Brownian-tree noise sampler built with cpu=False.

    Returns ``x`` untouched when ``sigmas`` has at most one entry. When no
    ``noise_sampler`` is given, one is built from ``x``, the smallest strictly
    positive sigma, the largest sigma and the ``"seed"`` entry of
    ``extra_args`` (``None`` if absent). ``r`` and the remaining arguments
    are forwarded unchanged to ``sample_dpmpp_sde``.
    """
    if len(sigmas) <= 1:
        return x

    # Normalise first so the seed lookup below cannot hit None.
    if extra_args is None:
        extra_args = {}

    # Computed unconditionally, exactly like the one-line form it replaces.
    sigma_min = sigmas[sigmas > 0].min()
    sigma_max = sigmas.max()
    if noise_sampler is None:
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)

    return sample_dpmpp_sde(
        model, x, sigmas,
        extra_args=extra_args, callback=callback, disable=disable,
        eta=eta, s_noise=s_noise, noise_sampler=noise_sampler,
        r=r,
    )
|
||||||
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
x = x + d_bar * dt
|
x = x + d_bar * dt
|
||||||
old_d = d
|
old_d = d
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
    """
    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.

    For every consecutive sigma pair the model is evaluated once and ``x`` is
    advanced with up to ``min(max_stage, i + 1)`` correction stages. The
    higher stages reuse the denoised outputs of previous steps.

    extra_args: keyword arguments forwarded to ``model``; its ``"seed"`` entry
        (if any) seeds the default noise sampler.
    callback: optional callable receiving a dict with keys ``x``, ``i``,
        ``sigma``, ``sigma_hat`` and ``denoised`` after each model call.
    s_noise: multiplier on the injected noise; ``0`` disables injection.
    noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
        ``default_noise_sampler(x, seed=seed)``.
    noise_scaler: callable applied to sigmas; defaults to
        ``sigma * (exp(sigma ** 0.3) + 10)``.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None))

    def _default_noise_scaler(sigma):
        return sigma * ((sigma ** 0.3).exp() + 10.0)

    if noise_scaler is None:
        noise_scaler = _default_noise_scaler

    batch_ones = x.new_ones([x.shape[0]])

    # Offsets 0..199 used to sample the sigma interval for the sum-based integrals.
    n_points = 200.0
    offsets = torch.arange(0, n_points, dtype=torch.float32, device=x.device)

    prev_denoised = None
    prev_slope = None

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        sigma_next = sigmas[i + 1]

        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        stage_used = min(max_stage, i + 1)
        if sigma_next == 0:
            # Final step: take the model prediction directly.
            x = denoised
        else:
            # Stage 1 (always applied when sigma_next is non-zero).
            r = noise_scaler(sigma_next) / noise_scaler(sigma_cur)
            x = r * x + (1 - r) * denoised

            if stage_used != 1:
                dt = sigma_next - sigma_cur
                step = -dt / n_points
                sigma_pos = sigma_next + offsets * step
                scaled_pos = noise_scaler(sigma_pos)

                # Stage 2
                s = torch.sum(1 / scaled_pos) * step
                slope = (denoised - prev_denoised) / (sigma_cur - sigmas[i - 1])
                x = x + (dt + s * noise_scaler(sigma_next)) * slope

                if stage_used >= 3:
                    # Stage 3
                    s_u = torch.sum((sigma_pos - sigma_cur) / scaled_pos) * step
                    curvature = (slope - prev_slope) / ((sigma_cur - sigmas[i - 2]) / 2)
                    x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigma_next)) * curvature
                prev_slope = slope

        # r is always bound here: this branch requires sigma_next > 0,
        # which excludes the only path above that skips computing r.
        if s_noise != 0 and sigma_next > 0:
            x = x + noise_sampler(sigma_cur, sigma_next) * s_noise * (sigma_next ** 2 - sigma_cur ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
        prev_denoised = denoised
    return x
|
||||||
|
|||||||
@@ -19,6 +19,10 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
class vector_quantize(Function):
|
class vector_quantize(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
|||||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||||
self.depthwise = nn.Sequential(
|
self.depthwise = nn.Sequential(
|
||||||
nn.ReplicationPad2d(1),
|
nn.ReplicationPad2d(1),
|
||||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||||
)
|
)
|
||||||
|
|
||||||
# channelwise
|
# channelwise
|
||||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||||
self.channelwise = nn.Sequential(
|
self.channelwise = nn.Sequential(
|
||||||
nn.Linear(c, c_hidden),
|
ops.Linear(c, c_hidden),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(c_hidden, c),
|
ops.Linear(c_hidden, c),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||||
@@ -171,16 +175,16 @@ class StageA(nn.Module):
|
|||||||
# Encoder blocks
|
# Encoder blocks
|
||||||
self.in_block = nn.Sequential(
|
self.in_block = nn.Sequential(
|
||||||
nn.PixelUnshuffle(2),
|
nn.PixelUnshuffle(2),
|
||||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||||
)
|
)
|
||||||
down_blocks = []
|
down_blocks = []
|
||||||
for i in range(levels):
|
for i in range(levels):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||||
down_blocks.append(block)
|
down_blocks.append(block)
|
||||||
down_blocks.append(nn.Sequential(
|
down_blocks.append(nn.Sequential(
|
||||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||||
))
|
))
|
||||||
self.down_blocks = nn.Sequential(*down_blocks)
|
self.down_blocks = nn.Sequential(*down_blocks)
|
||||||
@@ -191,7 +195,7 @@ class StageA(nn.Module):
|
|||||||
|
|
||||||
# Decoder blocks
|
# Decoder blocks
|
||||||
up_blocks = [nn.Sequential(
|
up_blocks = [nn.Sequential(
|
||||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||||
)]
|
)]
|
||||||
for i in range(levels):
|
for i in range(levels):
|
||||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||||
@@ -199,11 +203,11 @@ class StageA(nn.Module):
|
|||||||
up_blocks.append(block)
|
up_blocks.append(block)
|
||||||
if i < levels - 1:
|
if i < levels - 1:
|
||||||
up_blocks.append(
|
up_blocks.append(
|
||||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||||
padding=1))
|
padding=1))
|
||||||
self.up_blocks = nn.Sequential(*up_blocks)
|
self.up_blocks = nn.Sequential(*up_blocks)
|
||||||
self.out_block = nn.Sequential(
|
self.out_block = nn.Sequential(
|
||||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||||
nn.PixelShuffle(2),
|
nn.PixelShuffle(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
d = max(depth - 3, 3)
|
d = max(depth - 3, 3)
|
||||||
layers = [
|
layers = [
|
||||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||||
nn.LeakyReLU(0.2),
|
nn.LeakyReLU(0.2),
|
||||||
]
|
]
|
||||||
for i in range(depth - 1):
|
for i in range(depth - 1):
|
||||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||||
layers.append(nn.InstanceNorm2d(c_out))
|
layers.append(nn.InstanceNorm2d(c_out))
|
||||||
layers.append(nn.LeakyReLU(0.2))
|
layers.append(nn.LeakyReLU(0.2))
|
||||||
self.encoder = nn.Sequential(*layers)
|
self.encoder = nn.Sequential(*layers)
|
||||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||||
self.logits = nn.Sigmoid()
|
self.logits = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x, cond=None):
|
def forward(self, x, cond=None):
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
# EfficientNet
|
# EfficientNet
|
||||||
class EfficientNetEncoder(nn.Module):
|
class EfficientNetEncoder(nn.Module):
|
||||||
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||||
self.mapper = nn.Sequential(
|
self.mapper = nn.Sequential(
|
||||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||||
)
|
)
|
||||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||||
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
    """Rescale ``x`` via ``x * 0.5 + 0.5``, standardise it with ``self.mean`` / ``self.std``, then encode.

    The stored statistics are viewed as ``[3, 1, 1]`` and moved to the
    device and dtype of the rescaled input before use. The result is
    ``self.mapper(self.backbone(...))`` of the standardised tensor.
    """
    rescaled = x * 0.5 + 0.5
    # Move the statistics to the input's device/dtype rather than relying on the parameters' own placement.
    mean = self.mean.view([3, 1, 1]).to(device=rescaled.device, dtype=rescaled.dtype)
    std = self.std.view([3, 1, 1]).to(device=rescaled.device, dtype=rescaled.dtype)
    standardised = (rescaled - mean) / std
    return self.mapper(self.backbone(standardised))
|
||||||
|
|
||||||
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
|||||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
    """Build ``self.blocks``: conv / GELU / BatchNorm stages with three 2x transposed-conv upsamplings.

    Channel widths go ``c_in -> c_hidden -> c_hidden // 2 -> c_hidden // 4 -> c_out``.
    Convolutions come from the module-level ``ops``; activations and
    normalisation come from ``torch.nn``. Layer order (and therefore the
    ``nn.Sequential`` indices) is unchanged.
    """
    super().__init__()
    half = c_hidden // 2
    quarter = c_hidden // 4

    def act_norm(width):
        # Every convolution except the last is followed by GELU + BatchNorm2d.
        return [nn.GELU(), nn.BatchNorm2d(width)]

    layers = [
        ops.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
        *act_norm(c_hidden),

        ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
        *act_norm(c_hidden),

        ops.ConvTranspose2d(c_hidden, half, kernel_size=2, stride=2),  # 16 -> 32
        *act_norm(half),

        ops.Conv2d(half, half, kernel_size=3, padding=1),
        *act_norm(half),

        ops.ConvTranspose2d(half, quarter, kernel_size=2, stride=2),  # 32 -> 64
        *act_norm(quarter),

        ops.Conv2d(quarter, quarter, kernel_size=3, padding=1),
        *act_norm(quarter),

        ops.ConvTranspose2d(quarter, quarter, kernel_size=2, stride=2),  # 64 -> 128
        *act_norm(quarter),

        ops.Conv2d(quarter, quarter, kernel_size=3, padding=1),
        *act_norm(quarter),

        ops.Conv2d(quarter, c_out, kernel_size=1),
    ]
    self.blocks = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
|
|||||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, vec: Tensor) -> tuple:
|
def forward(self, vec: Tensor) -> tuple:
|
||||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
if vec.ndim == 2:
|
||||||
|
vec = vec[:, None, :]
|
||||||
|
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
ModulationOut(*out[:3]),
|
ModulationOut(*out[:3]),
|
||||||
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
    """Apply a multiplicative (and optional additive) modulation to ``tensor``.

    Without ``modulation_dims`` this returns a new tensor
    ``tensor * m_mult (+ m_add)`` using ordinary broadcasting.

    With ``modulation_dims``, each entry ``d`` modulates the token slice
    ``tensor[:, d[0]:d[1]]`` with modulation set ``d[2]``, i.e.
    ``m_mult[:, d[2]]`` and, if given, ``m_add[:, d[2]]``. In this mode
    ``tensor`` is modified IN PLACE and the same object is returned.

    The selected modulation has lost its set axis, so for a ``(B, T, D)``
    tensor it is ``(B, D)``. Broadcasting that directly would line the batch
    axis up with the token axis, which fails for B > 1, or silently mixes
    samples when the slice length happens to equal B. A singleton token
    axis is therefore inserted whenever the selection is exactly one rank
    short of ``tensor``. Results for batch size 1 are unchanged.
    """
    if modulation_dims is None:
        if m_add is not None:
            return tensor * m_mult + m_add
        return tensor * m_mult

    def _select(mod, index):
        picked = mod[:, index]
        if picked.ndim == tensor.ndim - 1:
            # (B, D) -> (B, 1, D): broadcast over tokens, not over the batch.
            picked = picked.unsqueeze(1)
        return picked

    for d in modulation_dims:
        start, stop, index = d[0], d[1], d[2]
        tensor[:, start:stop] *= _select(m_mult, index)
        if m_add is not None:
            tensor[:, start:stop] += _select(m_add, index)
    return tensor
|
||||||
|
|
||||||
|
|
||||||
class DoubleStreamBlock(nn.Module):
|
class DoubleStreamBlock(nn.Module):
|
||||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
self.flipped_img_txt = flipped_img_txt
|
self.flipped_img_txt = flipped_img_txt
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||||
img_qkv = self.img_attn.qkv(img_modulated)
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
# prepare txt for attention
|
# prepare txt for attention
|
||||||
txt_modulated = self.txt_norm1(txt)
|
txt_modulated = self.txt_norm1(txt)
|
||||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img blocks
|
# calculate the img blocks
|
||||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||||
|
|
||||||
# calculate the txt blocks
|
# calculate the txt blocks
|
||||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
|
|||||||
self.mlp_act = nn.GELU(approximate="tanh")
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||||
mod, _ = self.modulation(vec)
|
mod, _ = self.modulation(vec)
|
||||||
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
return x
|
return x
|
||||||
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
|
|||||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
    """Modulate the final-normalised ``x`` with shift/scale derived from ``vec`` and project it.

    A 2-D ``vec`` gets a singleton axis inserted at position 1 before
    ``self.adaLN_modulation`` is applied. The result is split in half along
    the last axis into ``shift`` and ``scale``. ``modulation_dims`` is passed
    through to ``apply_mod`` unchanged.
    """
    conditioning = vec.unsqueeze(1) if vec.ndim == 2 else vec
    shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1)
    modulated = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
    return self.linear(modulated)
|
||||||
|
|||||||
@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
|||||||
q_shape = q.shape
|
q_shape = q.shape
|
||||||
k_shape = k.shape
|
k_shape = k.shape
|
||||||
|
|
||||||
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
if pe is not None:
|
||||||
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||||
|
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||||
|
|
||||||
heads = q.shape[1]
|
heads = q.shape[1]
|
||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||||
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Apply the 2x2 mixing coefficients in ``freqs_cis`` to ``xq`` and ``xk``.

    Each input is cast to ``freqs_cis.dtype`` and its last axis is regrouped
    into consecutive pairs. Every pair ``(a, b)`` becomes
    ``freqs_cis[..., 0] * a + freqs_cis[..., 1] * b``. The result is reshaped
    back to the input's shape and cast back to the input's own dtype.

    Returns the tuple ``(xq_out, xk_out)``.
    """
    def _mix(t: Tensor) -> Tensor:
        # Compute in the dtype of the coefficients instead of forcing float32.
        pairs = t.to(dtype=freqs_cis.dtype).reshape(*t.shape[:-1], -1, 1, 2)
        mixed = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return mixed.reshape(*t.shape).type_as(t)

    return _mix(xq), _mix(xk)
|
||||||
|
|||||||
@@ -115,8 +115,11 @@ class Flux(nn.Module):
|
|||||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
if img_ids is not None:
|
||||||
pe = self.pe_embedder(ids)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
else:
|
||||||
|
pe = None
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
|||||||
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
|
guiding_frame_index=None,
|
||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
|
|||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
if guiding_frame_index is not None:
|
||||||
|
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||||
|
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||||
|
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||||
|
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||||
|
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||||
|
modulation_dims_txt = [(0, None, 1)]
|
||||||
|
else:
|
||||||
|
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||||
|
modulation_dims = None
|
||||||
|
modulation_dims_txt = None
|
||||||
|
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is not None:
|
if guidance is not None:
|
||||||
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
|
|||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
|
|
||||||
img = img[:, : img_len]
|
img = img[:, : img_len]
|
||||||
|
|
||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
|
||||||
shape = initial_shape[-3:]
|
shape = initial_shape[-3:]
|
||||||
for i in range(len(shape)):
|
for i in range(len(shape)):
|
||||||
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
|
|||||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
|
|||||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
if model_management.flash_attention_enabled():
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
try:
    # Register the flash-attn call as a torch custom operator; the fake
    # implementation below only has to report the output's shape/dtype.
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        """Forward q, k, v and the dropout/causal options to ``flash_attn_func``."""
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)


    @flash_attn_wrapper.register_fake
    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
        """Fake implementation: return an uninitialised tensor shaped like ``q``."""
        # Output shape is the same as q
        return q.new_empty(q.shape)
except AttributeError as error:
    # NOTE(review): presumably raised when this torch build has no
    # torch.library.custom_op -- confirm against the supported torch versions.
    FLASH_ATTN_ERROR = error

    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        """Stub used when the custom op could not be defined; always raises.

        Raises:
            AssertionError: always, carrying the error captured at import time.
        """
        # Raised explicitly (instead of ``assert False``) so the failure is
        # not stripped when Python runs with -O.
        raise AssertionError(f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}")
|
||||||
|
|
||||||
|
|
||||||
|
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Compute attention with ``flash_attn_wrapper``, falling back to PyTorch SDPA.

    The flash path is only attempted when ``mask`` is None. A masked call, or
    a flash call that raises, is served by
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q, k, v: Query, key and value tensors. With ``skip_reshape=False`` they
            are 3-D and their last dimension is split into ``heads`` heads
            here. With ``skip_reshape=True`` they are used as given and must
            already be 4-D ``(batch, heads, tokens, dim_head)``.
        heads: Number of attention heads.
        mask: Optional attention mask. A 2-D mask gets a leading batch
            dimension and a 3-D mask gets a heads dimension before it is
            handed to SDPA.
        attn_precision: Accepted for signature parity; not used here.
        skip_reshape: Skip splitting q/k/v into heads (see above).
        skip_output_reshape: Return the 4-D per-head output instead of merging
            the heads back into the last dimension.

    Returns:
        ``(batch, tokens, heads * dim_head)`` by default, or the 4-D per-head
        tensor when ``skip_output_reshape`` is true.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    out = None
    if mask is None:
        try:
            # The wrapper is fed tensors with tokens and heads swapped
            # relative to the SDPA layout, so transpose in and back out.
            out = flash_attn_wrapper(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                dropout_p=0.0,
                causal=False,
            ).transpose(1, 2)
        except Exception as e:
            logging.warning(f"Flash Attention failed, using default SDPA: {e}")
    else:
        # The flash path takes no mask, so masked calls go straight to SDPA.
        # This is an explicit branch rather than an ``assert`` so it still
        # holds under ``python -O`` and does not log a bogus failure.
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    if out is None:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        )
    return out
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
|
|||||||
elif model_management.xformers_enabled():
|
elif model_management.xformers_enabled():
|
||||||
logging.info("Using xformers attention")
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
|
elif model_management.flash_attention_enabled():
|
||||||
|
logging.info("Using Flash Attention")
|
||||||
|
optimized_attention = attention_flash
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
logging.info("Using pytorch attention")
|
logging.info("Using pytorch attention")
|
||||||
optimized_attention = attention_pytorch
|
optimized_attention = attention_pytorch
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
|
|||||||
context,
|
context,
|
||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
freqs=None,
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Forward pass through the diffusion model
|
Forward pass through the diffusion model
|
||||||
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
|
|||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
context = torch.concat([context_clip, context], dim=1)
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
|
||||||
# arguments
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
kwargs = dict(
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
e=e0,
|
for i, block in enumerate(self.blocks):
|
||||||
freqs=freqs,
|
if ("double_block", i) in blocks_replace:
|
||||||
context=context)
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
for block in self.blocks:
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||||
x = block(x, **kwargs)
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context)
|
||||||
|
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
|
|||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
@@ -104,11 +105,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.current_patcher: 'ModelPatcher' = None
|
self.current_patcher: ModelPatcher = None
|
||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
fp8 = model_config.optimizations.get("fp8", False)
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
logging.info("model_type {}".format(model_type.name))
|
logging.info("model_type {}".format(model_type.name))
|
||||||
logging.debug("adm {}".format(self.adm_channels))
|
logging.debug("adm {}".format(self.adm_channels))
|
||||||
self.memory_usage_factor = model_config.memory_usage_factor
|
self.memory_usage_factor = model_config.memory_usage_factor
|
||||||
|
self.zipper_initialized = False
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
|
|||||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
# handle lowvram zipper initialization, if required
|
||||||
|
if self.model_lowvram and not self.zipper_initialized:
|
||||||
|
if self.current_patcher:
|
||||||
|
self.zipper_initialized = True
|
||||||
|
with self.current_patcher.use_ejected():
|
||||||
|
loading = self.current_patcher._load_list_lowvram_only()
|
||||||
|
|
||||||
|
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
sigma = t
|
sigma = t
|
||||||
xc = self.model_sampling.calculate_input(sigma, x)
|
xc = self.model_sampling.calculate_input(sigma, x)
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
@@ -898,20 +910,31 @@ class HunyuanVideo(BaseModel):
|
|||||||
guidance = kwargs.get("guidance", 6.0)
|
guidance = kwargs.get("guidance", 6.0)
|
||||||
if guidance is not None:
|
if guidance is not None:
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||||
|
|
||||||
|
guiding_frame_index = kwargs.get("guiding_frame_index", None)
|
||||||
|
if guiding_frame_index is not None:
|
||||||
|
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, latent_image, **kwargs):
    """Return ``latent_image`` unchanged; extra keyword arguments are ignored."""
    unscaled = latent_image
    return unscaled
|
||||||
|
|
||||||
class HunyuanVideoI2V(HunyuanVideo):
|
class HunyuanVideoI2V(HunyuanVideo):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.concat_keys = ("concat_image", "mask_inverted")
|
self.concat_keys = ("concat_image", "mask_inverted")
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||||
|
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||||
|
|
||||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.concat_keys = ("concat_image",)
|
self.concat_keys = ("concat_image",)
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||||
|
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||||
|
|
||||||
class CosmosVideo(BaseModel):
|
class CosmosVideo(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||||
@@ -962,11 +985,11 @@ class WAN21(BaseModel):
|
|||||||
self.image_to_video = image_to_video
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
if not self.image_to_video:
|
noise = kwargs.get("noise", None)
|
||||||
|
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
image = kwargs.get("concat_latent_image", None)
|
image = kwargs.get("concat_latent_image", None)
|
||||||
noise = kwargs.get("noise", None)
|
|
||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
@@ -976,6 +999,9 @@ class WAN21(BaseModel):
|
|||||||
image = self.process_latent_in(image)
|
image = self.process_latent_in(image)
|
||||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||||
|
|
||||||
|
if not self.image_to_video:
|
||||||
|
return image
|
||||||
|
|
||||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = torch.zeros_like(noise)[:, :4]
|
mask = torch.zeros_like(noise)[:, :4]
|
||||||
|
|||||||
@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||||
if model_config.scaled_fp8 == torch.float32:
|
if model_config.scaled_fp8 == torch.float32:
|
||||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||||
|
if scaled_fp8_weight.nelement() == 2:
|
||||||
|
model_config.optimizations["fp8"] = False
|
||||||
|
else:
|
||||||
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|||||||
@@ -186,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
else:
|
else:
|
||||||
return mem_total
|
return mem_total
|
||||||
|
|
||||||
|
def mac_version():
    """Return the macOS version as a tuple of ints, or None if it cannot be parsed.

    The version string comes from ``platform.mac_ver()[0]`` and is split on
    ``"."``. Any ordinary failure while reading or parsing it yields None.
    """
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except Exception:
        # Best effort, but let KeyboardInterrupt/SystemExit propagate
        # instead of swallowing them with a bare ``except``.
        return None
|
||||||
|
|
||||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pytorch version: {}".format(torch_version))
|
logging.info("pytorch version: {}".format(torch_version))
|
||||||
|
mac_ver = mac_version()
|
||||||
|
if mac_ver is not None:
|
||||||
|
logging.info("Mac Version {}".format(mac_ver))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -581,7 +590,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
loaded_memory = loaded_model.model_loaded_memory()
|
loaded_memory = loaded_model.model_loaded_memory()
|
||||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
|
|
||||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
@@ -921,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
def sage_attention_enabled():
|
def sage_attention_enabled():
|
||||||
return args.use_sage_attention
|
return args.use_sage_attention
|
||||||
|
|
||||||
|
def flash_attention_enabled():
    """Return the value of the ``use_flash_attention`` command-line argument."""
    enabled = args.use_flash_attention
    return enabled
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
@@ -969,12 +981,6 @@ def pytorch_attention_flash_attention():
|
|||||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def mac_version():
|
|
||||||
try:
|
|
||||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def force_upcast_attention_dtype():
|
def force_upcast_attention_dtype():
|
||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Optional, Callable
|
from typing import Optional, Callable, TYPE_CHECKING
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
@@ -26,6 +26,7 @@ import uuid
|
|||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@@ -34,6 +35,9 @@ import comfy.hooks
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
@@ -201,7 +205,7 @@ class MemoryCounter:
|
|||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model: BaseModel = model
|
||||||
if not hasattr(self.model, 'device'):
|
if not hasattr(self.model, 'device'):
|
||||||
logging.debug("Model doesn't have a device attribute.")
|
logging.debug("Model doesn't have a device attribute.")
|
||||||
self.model.device = offload_device
|
self.model.device = offload_device
|
||||||
@@ -568,6 +572,14 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||||
|
|
||||||
|
def _zipper_dict_lowvram_only(self):
|
||||||
|
loading = self._load_list_lowvram_only()
|
||||||
|
|
||||||
|
|
||||||
|
def _load_list_lowvram_only(self):
|
||||||
|
loading = self._load_list()
|
||||||
|
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||||
|
|
||||||
def _load_list(self):
|
def _load_list(self):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
@@ -583,6 +595,35 @@ class ModelPatcher:
|
|||||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||||
return loading
|
return loading
|
||||||
|
|
||||||
|
def prepare_teeth(self):
    """Create a zipper tooth on every lowvram module and chain them together.

    Modules are taken from ``_load_list_lowvram_only()`` in list order. Each
    one gets ``init_tooth(load_device, offload_device, name)``. The teeth are
    then linked in both directions: the first tooth gets ``start = True``,
    every later tooth gets ``prev_tooth``, and every tooth except the last
    gets ``next_tooth``. An empty list is a no-op.
    """
    ordered_list = self._load_list_lowvram_only()

    # first, create teeth on modules in list
    for entry in ordered_list:
        module = entry[2]
        module.init_tooth(self.load_device, self.offload_device, entry[1])

    # create teeth linked list
    modules = [entry[2] for entry in ordered_list]
    last_index = len(modules) - 1
    for index, module in enumerate(modules):
        tooth = module.zipper_tooth
        if index == 0:
            tooth.start = True
        else:
            tooth.prev_tooth = modules[index - 1].zipper_tooth
        if index < last_index:
            tooth.next_tooth = modules[index + 1].zipper_tooth
|
||||||
|
|
||||||
|
def clean_teeth(self):
    """Call ``clean_tooth()`` on every module returned by ``_load_list_lowvram_only()``."""
    modules = (entry[2] for entry in self._load_list_lowvram_only())
    for module in modules:
        module.clean_tooth()
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
@@ -591,6 +632,8 @@ class ModelPatcher:
|
|||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
loading = self._load_list()
|
loading = self._load_list()
|
||||||
|
|
||||||
|
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
|
||||||
|
|
||||||
load_completely = []
|
load_completely = []
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
@@ -672,6 +715,7 @@ class ModelPatcher:
|
|||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
|
self.model.zipper_initialized = False
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
@@ -684,6 +728,9 @@ class ModelPatcher:
|
|||||||
self.model.model_loaded_weight_memory = mem_counter
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
|
|
||||||
|
if self.model.model_lowvram:
|
||||||
|
self.prepare_teeth()
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||||
|
|
||||||
@@ -715,6 +762,7 @@ class ModelPatcher:
|
|||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
|
self.clean_teeth()
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
self.model.lowvram_patch_counter = 0
|
self.model.lowvram_patch_counter = 0
|
||||||
|
|
||||||
@@ -747,6 +795,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def partially_unload(self, device_to, memory_to_free=0):
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
|
hooks_unpatched = False
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
unload_list = self._load_list()
|
unload_list = self._load_list()
|
||||||
@@ -770,6 +819,10 @@ class ModelPatcher:
|
|||||||
move_weight = False
|
move_weight = False
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if not hooks_unpatched:
|
||||||
|
self.unpatch_hooks()
|
||||||
|
hooks_unpatched = True
|
||||||
|
|
||||||
if bk.inplace_update:
|
if bk.inplace_update:
|
||||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||||
else:
|
else:
|
||||||
@@ -799,8 +852,10 @@ class ModelPatcher:
|
|||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
|
self.model.zipper_initialized = False
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.model_loaded_weight_memory -= memory_freed
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
|
self.prepare_teeth()
|
||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
@@ -1089,7 +1144,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
|
||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
model_sd_keys = list(self.model_state_dict().keys())
|
model_sd_keys = list(self.model_state_dict().keys())
|
||||||
memory_counter = None
|
memory_counter = None
|
||||||
@@ -1100,12 +1154,16 @@ class ModelPatcher:
|
|||||||
# if have cached weights for hooks, use it
|
# if have cached weights for hooks, use it
|
||||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||||
if cached_weights is not None:
|
if cached_weights is not None:
|
||||||
|
model_sd_keys_set = set(model_sd_keys)
|
||||||
for key in cached_weights:
|
for key in cached_weights:
|
||||||
if key not in model_sd_keys:
|
if key not in model_sd_keys:
|
||||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||||
continue
|
continue
|
||||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||||
|
model_sd_keys_set.remove(key)
|
||||||
|
self.unpatch_hooks(model_sd_keys_set)
|
||||||
else:
|
else:
|
||||||
|
self.unpatch_hooks()
|
||||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||||
original_weights = None
|
original_weights = None
|
||||||
if len(relevant_patches) > 0:
|
if len(relevant_patches) > 0:
|
||||||
@@ -1116,6 +1174,8 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||||
memory_counter=memory_counter)
|
memory_counter=memory_counter)
|
||||||
|
else:
|
||||||
|
self.unpatch_hooks()
|
||||||
self.current_hooks = hooks
|
self.current_hooks = hooks
|
||||||
|
|
||||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||||
@@ -1172,17 +1232,23 @@ class ModelPatcher:
|
|||||||
del out_weight
|
del out_weight
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
def unpatch_hooks(self) -> None:
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
if len(self.hook_backup) == 0:
|
if len(self.hook_backup) == 0:
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
return
|
return
|
||||||
keys = list(self.hook_backup.keys())
|
keys = list(self.hook_backup.keys())
|
||||||
for k in keys:
|
if whitelist_keys_set:
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
for k in keys:
|
||||||
|
if k in whitelist_keys_set:
|
||||||
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
self.hook_backup.pop(k)
|
||||||
|
else:
|
||||||
|
for k in keys:
|
||||||
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
self.hook_backup.clear()
|
self.hook_backup.clear()
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
|
|
||||||
def clean_hooks(self):
|
def clean_hooks(self):
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
|
|||||||
124
comfy/ops.py
124
comfy/ops.py
@@ -16,7 +16,9 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
@@ -55,6 +57,79 @@ class CastWeightBiasOp:
|
|||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
weight_function = []
|
weight_function = []
|
||||||
bias_function = []
|
bias_function = []
|
||||||
|
zipper_init: dict = None
|
||||||
|
zipper_tooth: ZipperTooth = None
|
||||||
|
_zipper_tooth: ZipperTooth = None
|
||||||
|
|
||||||
|
def init_tooth(self, load_device, offload_device, key: str=None):
    """Attach a freshly built ZipperTooth to this op.

    Any tooth that is already attached is cleaned up first via clean_tooth().

    Args:
        load_device: passed through to the new ZipperTooth.
        offload_device: passed through to the new ZipperTooth.
        key: optional identifier passed through to the new ZipperTooth.
    """
    # Drop a tooth left over from an earlier setup before installing the new one.
    stale_tooth = self.zipper_tooth
    if stale_tooth:
        self.clean_tooth()
    fresh_tooth = ZipperTooth(self, load_device, offload_device, key)
    self.zipper_tooth = fresh_tooth
|
||||||
|
|
||||||
|
def clean_tooth(self):
    """Detach this op's zipper tooth, leaving ``zipper_tooth`` set to None.

    Does nothing when no (truthy) tooth is currently attached.
    """
    if not self.zipper_tooth:
        return
    # Remove the attribute, then store an explicit None in its place.
    del self.zipper_tooth
    self.zipper_tooth = None
|
||||||
|
|
||||||
|
def connect_teeth(self):
    """Register this op in the shared zipper mapping and mark it as the newest entry.

    When ``self.zipper_init`` is None nothing happens. Otherwise the mapping
    gains an entry ``zipper_key -> (has_prev_cast_flag, previous_key)`` where
    ``has_prev_cast_flag`` tells whether this op has a
    ``prev_comfy_cast_weights`` attribute and ``previous_key`` is the key that
    was registered before this one (None for the first op). The mapping's
    ``"prev_zipper_key"`` slot is then updated to this op's key.

    NOTE(review): the previous body guarded on ``self.zipper_init`` but read
    and wrote ``self.zipper_dict``, an attribute nothing visible defines. The
    mapping declared on the class is ``zipper_init``, so it is used throughout
    here - confirm this is the intended mapping.
    """
    registry = self.zipper_init
    if registry is None:
        return

    has_prev_cast_flag = hasattr(self, "prev_comfy_cast_weights")
    previous_key = registry.get("prev_zipper_key", None)
    registry[self.zipper_key] = (has_prev_cast_flag, previous_key)
    # Remember this op so the next one registered can link back to it.
    registry["prev_zipper_key"] = self.zipper_key
|
||||||
|
|
||||||
|
# def zipper_connect(self):
|
||||||
|
# if self.zipper_dict is not None:
|
||||||
|
# self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
|
||||||
|
# self.zipper_dict["prev_zipper_key"] = self.zipper_key
|
||||||
|
|
||||||
|
class ZipperTooth:
    """One link in a doubly linked chain of weight-casting ops.

    A tooth wraps a single op and holds cast copies of that op's weight and
    bias. Asking a tooth for its tensors also triggers the cast for the tooth
    that follows it, so neighbouring ops can be prepared ahead of use.
    """

    def __init__(self, op: "CastWeightBiasOp", load_device, offload_device, key: str=None):
        """Create an unlinked tooth for ``op``.

        Args:
            op: the op whose ``weight`` and ``bias`` this tooth casts.
            load_device: device queried for non-blocking transfer support.
            offload_device: stored for later use; not read inside this class.
            key: optional identifier for the wrapped op.
        """
        self.op = op
        self.key: str = key
        self.weight_preloaded: torch.Tensor = None
        self.bias_preloaded: torch.Tensor = None
        self.load_device = load_device
        self.offload_device = offload_device
        # When True, get_bias_weight() casts on demand instead of returning
        # the preloaded tensors.
        self.start = False

        # Neighbours in the chain; None marks either end.
        self.prev_tooth: ZipperTooth = None
        self.next_tooth: ZipperTooth = None

    def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
        """Return ``(weight, bias)`` for the wrapped op and preload the next tooth.

        If ``self.start`` is set the tensors are cast on demand through
        ``cast_bias_weight``; otherwise the previously preloaded pair is
        returned. In both cases (including when the cast raises) the next tooth,
        if there is one, is asked to preload with the same arguments.
        """
        try:
            if self.start:
                return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
            return self.weight_preloaded, self.bias_preloaded
        finally:
            # TODO: also release the previous tooth here
            # (self.prev_tooth.offload_previous(0)) once offloading is wired up.
            # The last tooth in a chain has no successor; without this guard
            # the finally block raised AttributeError and masked the result.
            if self.next_tooth is not None:
                self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)

    def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
        """Cast this tooth's weight and bias, then continue down the chain.

        Args:
            teeth_count: how many further teeth to preload after this one.
            input: optional tensor supplying defaults for ``dtype`` and ``device``.
            dtype: target dtype for the weight.
            device: target device for both tensors.
            bias_dtype: target dtype for the bias; defaults to ``dtype`` when
                ``input`` is given.
        """
        # TODO: queue load of tensors
        if input is not None:
            if dtype is None:
                dtype = input.dtype
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)

        if self.op.bias is not None:
            self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)

        self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
        if self.next_tooth and teeth_count:
            # Forward the resolved targets; the recursion used to drop them, so
            # every further tooth was cast with dtype=None and device=None.
            self.next_tooth.preload_next(teeth_count - 1, dtype=dtype, device=device, bias_dtype=bias_dtype)

    def offload_previous(self, teeth_count=1):
        """Drop this tooth's preloaded tensors, then those of earlier teeth.

        Args:
            teeth_count: how many preceding teeth to clear after this one.
        """
        # TODO: queue offload of tensors
        self.weight_preloaded = None
        self.bias_preloaded = None
        if self.prev_tooth and teeth_count:
            self.prev_tooth.offload_previous(teeth_count - 1)
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
@@ -62,7 +137,11 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
#if self.zipper_init:
|
||||||
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -76,7 +155,10 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -90,7 +172,10 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -104,7 +189,10 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -118,7 +206,10 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -133,7 +224,10 @@ class disable_weight_init:
|
|||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
if self.weight is not None:
|
if self.weight is not None:
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
else:
|
else:
|
||||||
weight = None
|
weight = None
|
||||||
bias = None
|
bias = None
|
||||||
@@ -155,7 +249,10 @@ class disable_weight_init:
|
|||||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||||
num_spatial_dims, self.dilation)
|
num_spatial_dims, self.dilation)
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.conv_transpose2d(
|
return torch.nn.functional.conv_transpose2d(
|
||||||
input, weight, bias, self.stride, self.padding,
|
input, weight, bias, self.stride, self.padding,
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
@@ -176,7 +273,10 @@ class disable_weight_init:
|
|||||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||||
num_spatial_dims, self.dilation)
|
num_spatial_dims, self.dilation)
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(input)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.conv_transpose1d(
|
return torch.nn.functional.conv_transpose1d(
|
||||||
input, weight, bias, self.stride, self.padding,
|
input, weight, bias, self.stride, self.padding,
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
@@ -196,7 +296,10 @@ class disable_weight_init:
|
|||||||
output_dtype = out_dtype
|
output_dtype = out_dtype
|
||||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||||
out_dtype = None
|
out_dtype = None
|
||||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
if self.zipper_tooth:
|
||||||
|
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
|
||||||
|
else:
|
||||||
|
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -308,6 +411,7 @@ class fp8_ops(manual_cast):
|
|||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||||
|
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||||
class scaled_fp8_op(manual_cast):
|
class scaled_fp8_op(manual_cast):
|
||||||
class Linear(manual_cast.Linear):
|
class Linear(manual_cast.Linear):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -358,7 +462,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
if scaled_fp8 is not None:
|
||||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
fp8_compute and
|
fp8_compute and
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.model_base import BaseModel
|
from comfy.model_base import BaseModel
|
||||||
from comfy.controlnet import ControlBase
|
from comfy.controlnet import ControlBase
|
||||||
|
from comfy.ops import CastWeightBiasOp
|
||||||
import torch
|
import torch
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import collections
|
import collections
|
||||||
@@ -18,6 +19,7 @@ import comfy.patcher_extension
|
|||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
|
||||||
def add_area_dims(area, num_dims):
|
def add_area_dims(area, num_dims):
|
||||||
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
|||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||||
uncond_ = None
|
uncond_ = None
|
||||||
else:
|
else:
|
||||||
uncond_ = uncond
|
uncond_ = uncond
|
||||||
|
|
||||||
|
do_cleanup = False
|
||||||
|
if "weight_zipper" not in model_options:
|
||||||
|
do_cleanup = True
|
||||||
|
#zipper_dict = {}
|
||||||
|
model_options["weight_zipper"] = True
|
||||||
|
loaded_modules = model.current_patcher._load_list_lowvram_only()
|
||||||
|
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||||
|
sum_m = sum([x[0] for x in low_m])
|
||||||
|
for l in loaded_modules:
|
||||||
|
m: CastWeightBiasOp = l[2]
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
m.zipper_tooth = comfy.ops.ZipperTooth
|
||||||
|
#m.zipper_dict = zipper_dict
|
||||||
|
m.zipper_key = l[1]
|
||||||
|
|
||||||
conds = [cond, uncond_]
|
conds = [cond, uncond_]
|
||||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||||
|
|
||||||
|
if do_cleanup:
|
||||||
|
zzz = 20
|
||||||
|
for l in loaded_modules:
|
||||||
|
m: CastWeightBiasOp = l[2]
|
||||||
|
if hasattr(l[2], "comfy_cast_weights"):
|
||||||
|
#m.zipper_dict = None
|
||||||
|
m.zipper_key = None
|
||||||
|
|
||||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||||
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||||
@@ -710,7 +735,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation"]
|
"gradient_estimation", "er_sde"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
|||||||
15
comfy/sd.py
15
comfy/sd.py
@@ -440,6 +440,10 @@ class VAE:
|
|||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
|
||||||
|
def throw_exception_if_invalid(self):
    """Fail fast when this VAE has no underlying model.

    Raises:
        RuntimeError: if ``self.first_stage_model`` is None.
    """
    if self.first_stage_model is not None:
        return
    raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
downscale_ratio = self.spacial_compression_encode()
|
downscale_ratio = self.spacial_compression_encode()
|
||||||
|
|
||||||
@@ -495,6 +499,7 @@ class VAE:
|
|||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
|
self.throw_exception_if_invalid()
|
||||||
pixel_samples = None
|
pixel_samples = None
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
@@ -525,6 +530,7 @@ class VAE:
|
|||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
|
self.throw_exception_if_invalid()
|
||||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
dims = samples.ndim - 2
|
dims = samples.ndim - 2
|
||||||
@@ -553,6 +559,7 @@ class VAE:
|
|||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
|
self.throw_exception_if_invalid()
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||||
@@ -585,6 +592,7 @@ class VAE:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
|
self.throw_exception_if_invalid()
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
dims = self.latent_dim
|
dims = self.latent_dim
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
return None
|
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||||
|
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||||
|
if diffusion_model is None:
|
||||||
|
return None
|
||||||
|
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||||
|
|
||||||
|
|
||||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
if model_config.scaled_fp8 is not None:
|
if model_config.scaled_fp8 is not None:
|
||||||
|
|||||||
@@ -228,6 +228,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
if pad_extra > 0:
|
if pad_extra > 0:
|
||||||
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
|
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
|
||||||
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
|
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
|
||||||
|
attention_mask = attention_mask + [0] * pad_extra
|
||||||
|
|
||||||
embeds_out.append(tokens_embed)
|
embeds_out.append(tokens_embed)
|
||||||
attention_masks.append(attention_mask)
|
attention_masks.append(attention_mask)
|
||||||
|
|||||||
@@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
|
|||||||
|
|
||||||
memory_usage_factor = 1.0
|
memory_usage_factor = 1.0
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class HunyuanVideoTokenizer:
|
|||||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ class HunyuanVideoTokenizer:
|
|||||||
for i in range(len(r)):
|
for i in range(len(r)):
|
||||||
if r[i][0] == 128257:
|
if r[i][0] == 128257:
|
||||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
|
||||||
embed_count += 1
|
embed_count += 1
|
||||||
out["llama"] = llama_text_tokens
|
out["llama"] = llama_text_tokens
|
||||||
return out
|
return out
|
||||||
@@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
|||||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||||
|
|
||||||
template_end = 0
|
template_end = 0
|
||||||
image_start = None
|
extra_template_end = 0
|
||||||
image_end = None
|
|
||||||
extra_sizes = 0
|
extra_sizes = 0
|
||||||
user_end = 9999999999999
|
user_end = 9999999999999
|
||||||
|
images = []
|
||||||
|
|
||||||
tok_pairs = token_weight_pairs_llama[0]
|
tok_pairs = token_weight_pairs_llama[0]
|
||||||
for i, v in enumerate(tok_pairs):
|
for i, v in enumerate(tok_pairs):
|
||||||
@@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
if elem.get("original_type") == "image":
|
if elem.get("original_type") == "image":
|
||||||
elem_size = elem.get("data").shape[0]
|
elem_size = elem.get("data").shape[0]
|
||||||
if image_start is None:
|
if template_end > 0:
|
||||||
|
if user_end == -1:
|
||||||
|
extra_template_end += elem_size - 1
|
||||||
|
else:
|
||||||
image_start = i + extra_sizes
|
image_start = i + extra_sizes
|
||||||
image_end = i + elem_size + extra_sizes
|
image_end = i + elem_size + extra_sizes
|
||||||
extra_sizes += elem_size - 1
|
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||||
|
extra_sizes += elem_size - 1
|
||||||
|
|
||||||
if llama_out.shape[1] > (template_end + 2):
|
if llama_out.shape[1] > (template_end + 2):
|
||||||
if tok_pairs[template_end + 1][0] == 271:
|
if tok_pairs[template_end + 1][0] == 271:
|
||||||
template_end += 2
|
template_end += 2
|
||||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
|
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
|
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||||
|
|
||||||
if image_start is not None:
|
if len(images) > 0:
|
||||||
image_output = llama_out[:, image_start: image_end]
|
out = []
|
||||||
llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
|
for i in images:
|
||||||
|
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||||
|
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||||
|
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return llama_output, l_pooled, llama_extra_out
|
return llama_output, l_pooled, llama_extra_out
|
||||||
|
|||||||
@@ -57,17 +57,17 @@ class TextEncodeHunyuanVideo_ImageToVideo:
|
|||||||
"clip": ("CLIP", ),
|
"clip": ("CLIP", ),
|
||||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
|
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
CATEGORY = "advanced/conditioning"
|
||||||
|
|
||||||
def encode(self, clip, clip_vision_output, prompt):
|
def encode(self, clip, clip_vision_output, prompt, image_interleave):
|
||||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
|
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
|
|
||||||
|
|
||||||
class HunyuanImageToVideo:
|
class HunyuanImageToVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -77,6 +77,7 @@ class HunyuanImageToVideo:
|
|||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||||
},
|
},
|
||||||
"optional": {"start_image": ("IMAGE", ),
|
"optional": {"start_image": ("IMAGE", ),
|
||||||
}}
|
}}
|
||||||
@@ -87,8 +88,10 @@ class HunyuanImageToVideo:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
CATEGORY = "conditioning/video_models"
|
||||||
|
|
||||||
def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
|
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
out_latent = {}
|
||||||
|
|
||||||
if start_image is not None:
|
if start_image is not None:
|
||||||
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
|
||||||
@@ -96,13 +99,20 @@ class HunyuanImageToVideo:
|
|||||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
if guidance_type == "v1 (concat)":
|
||||||
|
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||||
|
else:
|
||||||
|
cond = {'guiding_frame_index': 0}
|
||||||
|
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||||
|
out_latent["noise_mask"] = mask
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||||
|
|
||||||
out_latent = {}
|
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, out_latent)
|
return (positive, out_latent)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ class Load3D():
|
|||||||
"image": ("LOAD_3D", {}),
|
"image": ("LOAD_3D", {}),
|
||||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
|
||||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||||
@@ -55,8 +53,6 @@ class Load3DAnimation():
|
|||||||
"image": ("LOAD_3D_ANIMATION", {}),
|
"image": ("LOAD_3D_ANIMATION", {}),
|
||||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
|
||||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||||
@@ -82,8 +78,6 @@ class Preview3D():
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
|
||||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@@ -102,8 +96,6 @@ class Preview3DAnimation():
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
"material": (["original", "normal", "wireframe", "depth"],),
|
|
||||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|||||||
@@ -99,12 +99,13 @@ class LTXVAddGuide:
|
|||||||
"negative": ("CONDITIONING", ),
|
"negative": ("CONDITIONING", ),
|
||||||
"vae": ("VAE",),
|
"vae": ("VAE",),
|
||||||
"latent": ("LATENT",),
|
"latent": ("LATENT",),
|
||||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
|
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
|
||||||
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
|
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
|
||||||
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
|
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
|
||||||
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
|
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
|
||||||
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
|
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
|
||||||
"Negative values are counted from the end of the video."}),
|
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
|
||||||
|
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,12 +128,13 @@ class LTXVAddGuide:
|
|||||||
t = vae.encode(encode_pixels)
|
t = vae.encode(encode_pixels)
|
||||||
return encode_pixels, t
|
return encode_pixels, t
|
||||||
|
|
||||||
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
|
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||||
time_scale_factor, _, _ = scale_factors
|
time_scale_factor, _, _ = scale_factors
|
||||||
_, num_keyframes = get_keyframe_idxs(cond)
|
_, num_keyframes = get_keyframe_idxs(cond)
|
||||||
latent_count = latent_length - num_keyframes
|
latent_count = latent_length - num_keyframes
|
||||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
|
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
if guide_length > 1:
|
||||||
|
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||||
|
|
||||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||||
|
|
||||||
@@ -191,7 +193,7 @@ class LTXVAddGuide:
|
|||||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||||
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
|
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||||
|
|
||||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
|
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||||
|
|
||||||
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
|
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.23"
|
__version__ = "0.3.26"
|
||||||
|
|||||||
@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||||
|
# list widget value to execution, as by default list value is
|
||||||
|
# reserved to represent the connection between nodes.
|
||||||
|
if isinstance(val, dict) and "__value__" in val:
|
||||||
|
val = val["__value__"]
|
||||||
|
inputs[x] = val
|
||||||
|
|
||||||
if type_input == "INT":
|
if type_input == "INT":
|
||||||
val = int(val)
|
val = int(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
|
|||||||
6
main.py
6
main.py
@@ -139,6 +139,7 @@ from server import BinaryEventTypes
|
|||||||
import nodes
|
import nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfyui_version
|
import comfyui_version
|
||||||
|
import app.logger
|
||||||
|
|
||||||
|
|
||||||
def cuda_malloc_warning():
|
def cuda_malloc_warning():
|
||||||
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Running directly, just start ComfyUI.
|
# Running directly, just start ComfyUI.
|
||||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||||
|
|
||||||
event_loop, _, start_all_func = start_comfyui()
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
try:
|
try:
|
||||||
event_loop.run_until_complete(start_all_func())
|
x = start_all_func()
|
||||||
|
app.logger.print_startup_warnings()
|
||||||
|
event_loop.run_until_complete(x)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("\nStopped server")
|
logging.info("\nStopped server")
|
||||||
|
|
||||||
|
|||||||
12
nodes.py
12
nodes.py
@@ -489,7 +489,7 @@ class SaveLatent:
|
|||||||
file = os.path.join(full_output_folder, file)
|
file = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
output = {}
|
output = {}
|
||||||
output["latent_tensor"] = samples["samples"]
|
output["latent_tensor"] = samples["samples"].contiguous()
|
||||||
output["latent_format_version_0"] = torch.tensor([])
|
output["latent_format_version_0"] = torch.tensor([])
|
||||||
|
|
||||||
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
||||||
@@ -770,6 +770,7 @@ class VAELoader:
|
|||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
|
|||||||
|
|
||||||
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
|
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
FUNCTION = "load_image_output"
|
FUNCTION = "load_image"
|
||||||
|
|
||||||
def load_image_output(self, image):
|
|
||||||
return self.load_image(f"{image} [output]")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def VALIDATE_INPUTS(s, image):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class ImageScale:
|
class ImageScale:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.23"
|
version = "0.3.26"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.10.17
|
comfyui-frontend-package==1.12.14
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
|
|||||||
def test_init_frontend_default():
|
def test_init_frontend_default():
|
||||||
version_string = DEFAULT_VERSION_STRING
|
version_string = DEFAULT_VERSION_STRING
|
||||||
frontend_path = FrontendManager.init_frontend(version_string)
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
assert frontend_path == FrontendManager.default_frontend_path()
|
||||||
|
|
||||||
|
|
||||||
def test_init_frontend_invalid_version():
|
def test_init_frontend_invalid_version():
|
||||||
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
|
|||||||
with pytest.raises(HTTPError):
|
with pytest.raises(HTTPError):
|
||||||
FrontendManager.init_frontend_unsafe(version_string)
|
FrontendManager.init_frontend_unsafe(version_string)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_os_functions():
|
def mock_os_functions():
|
||||||
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
with (
|
||||||
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
patch("app.frontend_management.os.makedirs") as mock_makedirs,
|
||||||
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
patch("app.frontend_management.os.listdir") as mock_listdir,
|
||||||
|
patch("app.frontend_management.os.rmdir") as mock_rmdir,
|
||||||
|
):
|
||||||
mock_listdir.return_value = [] # Simulate empty directory
|
mock_listdir.return_value = [] # Simulate empty directory
|
||||||
yield mock_makedirs, mock_listdir, mock_rmdir
|
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_download():
|
def mock_download():
|
||||||
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
with patch("app.frontend_management.download_release_asset_zip") as mock:
|
||||||
mock.side_effect = Exception("Download failed") # Simulate download failure
|
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||||
version_string = 'test-owner/test-repo@1.0.0'
|
version_string = "test-owner/test-repo@1.0.0"
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
|
|||||||
version_string = "invalid"
|
version_string = "invalid"
|
||||||
with pytest.raises(argparse.ArgumentTypeError):
|
with pytest.raises(argparse.ArgumentTypeError):
|
||||||
FrontendManager.parse_version_string(version_string)
|
FrontendManager.parse_version_string(version_string)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_frontend_default_with_mocks():
|
||||||
|
# Arrange
|
||||||
|
version_string = DEFAULT_VERSION_STRING
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with (
|
||||||
|
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||||
|
patch.object(
|
||||||
|
FrontendManager, "default_frontend_path", return_value="/mocked/path"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert frontend_path == "/mocked/path"
|
||||||
|
mock_check.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_frontend_fallback_on_error():
|
||||||
|
# Arrange
|
||||||
|
version_string = "test-owner/test-repo@1.0.0"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
|
||||||
|
),
|
||||||
|
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||||
|
patch.object(
|
||||||
|
FrontendManager, "default_frontend_path", return_value="/default/path"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert frontend_path == "/default/path"
|
||||||
|
mock_check.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user