Compare commits
41 commits: v0.3.25...video_outp
| SHA1 |
|---|
| 60b459bb4c |
| 3b19fc76e3 |
| 50614f1b79 |
| 6dc7b0bfe3 |
| e8e990d6b8 |
| 2e24a15905 |
| fd5297131f |
| 55a1b09ddc |
| 3c3988df45 |
| 7ebd8087ff |
| c624c29d66 |
| a2448fc527 |
| 6a0daa79b6 |
| 9c98c6358b |
| 7aceb9f91c |
| 35504e2f93 |
| 299436cfed |
| 52e566d2bc |
| 9b6cd9b874 |
| 3fc688aebd |
| f4411250f3 |
| d2a0fb6bb0 |
| 01015bff16 |
| 2330754b0e |
| bc219a6487 |
| 94689766ad |
| cfbe4b49ca |
| ca8efab79f |
| 65ea778a5e |
| db9f2a34fc |
| 7946049794 |
| 6f6349b6a7 |
| 1f138dd382 |
| b779349b55 |
| 35e2dcf5d7 |
| 67c7184b74 |
| 6f8e766509 |
| e1da98a14a |
| a73410aafa |
| 9aac21f894 |
| 528d1b3563 |
CODEOWNERS

@@ -19,5 +19,6 @@
 /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 
-# Extra nodes
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
+# Node developers
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
app/frontend_management.py

@@ -11,33 +11,44 @@ from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import TypedDict, Optional
+from importlib.metadata import version
 
 import requests
 from typing_extensions import NotRequired
 
 from comfy.cli_args import DEFAULT_VERSION_STRING
+import app.logger
+
+# The path to the requirements.txt file
+req_path = Path(__file__).parents[1] / "requirements.txt"
 
 
 def frontend_install_warning_message():
-    req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
+    """The warning message to display when the frontend version is not up to date."""
+
     extra = ""
     if sys.flags.no_user_site:
         extra = "-s "
     return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
 
-try:
-    import comfyui_frontend_package
-except ImportError:
-    # TODO: Remove the check after roll out of 0.3.16
-    logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
-    exit(-1)
 
+def check_frontend_version():
+    """Check if the frontend version is up to date."""
+
+    def parse_version(version: str) -> tuple[int, int, int]:
+        return tuple(map(int, version.split(".")))
+
+    try:
+        frontend_version_str = version("comfyui-frontend-package")
+        frontend_version = parse_version(frontend_version_str)
+        with open(req_path, "r", encoding="utf-8") as f:
+            required_frontend = parse_version(f.readline().split("=")[-1])
+        if frontend_version < required_frontend:
+            app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
+        else:
+            logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
+    except Exception as e:
+        logging.error(f"Failed to check frontend version: {e}")
 
-try:
-    frontend_version = tuple(map(int, comfyui_frontend_package.__version__.split(".")))
-except:
-    frontend_version = (0,)
-    pass
 
 REQUEST_TIMEOUT = 10  # seconds

@@ -133,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
 
 
 class FrontendManager:
-    DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
 
+    @classmethod
+    def default_frontend_path(cls) -> str:
+        try:
+            import comfyui_frontend_package
+
+            return str(importlib.resources.files(comfyui_frontend_package) / "static")
+        except ImportError:
+            logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
+            sys.exit(-1)
+
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """

@@ -172,7 +191,8 @@ class FrontendManager:
         main error source might be request timeout or invalid URL.
         """
         if version_string == DEFAULT_VERSION_STRING:
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()
 
         repo_owner, repo_name, version = cls.parse_version_string(version_string)

@@ -225,4 +245,5 @@ class FrontendManager:
         except Exception as e:
             logging.error("Failed to initialize frontend: %s", e)
             logging.info("Falling back to the default frontend.")
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()
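Note: the new check compares the installed comfyui-frontend-package against the first line of requirements.txt by parsing both into integer tuples, which compare element-wise. A minimal standalone sketch of that comparison (the version strings here are illustrative, not from this diff):

```python
# Tuple-based version comparison, mirroring parse_version in the diff above.
def parse_version(version: str) -> tuple[int, ...]:
    return tuple(map(int, version.split(".")))

# split("=")[-1] strips the "comfyui-frontend-package==" prefix, as in the diff.
required = parse_version("comfyui-frontend-package==1.10.17".split("=")[-1])
installed = parse_version("1.9.7")

if installed < required:  # (1, 9, 7) < (1, 10, 17): element-wise comparison
    print(f"Installed frontend {installed} is older than required {required}")
```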
app/logger.py

@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
         logger.addHandler(stdout_handler)
 
     logger.addHandler(stream_handler)
+
+
+STARTUP_WARNINGS = []
+
+
+def log_startup_warning(msg):
+    logging.warning(msg)
+    STARTUP_WARNINGS.append(msg)
+
+
+def print_startup_warnings():
+    for s in STARTUP_WARNINGS:
+        logging.warning(s)
+    STARTUP_WARNINGS.clear()
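Note: this adds a deferred-warning buffer, presumably so warnings raised during startup can be replayed once something is attached that can display them. A self-contained sketch of the pattern (the trigger message is illustrative):

```python
import logging

STARTUP_WARNINGS = []

def log_startup_warning(msg):
    logging.warning(msg)          # visible in the console immediately
    STARTUP_WARNINGS.append(msg)  # buffered for a later replay

def print_startup_warnings():
    for s in STARTUP_WARNINGS:
        logging.warning(s)        # re-emitted once a listener is ready
    STARTUP_WARNINGS.clear()      # replay only once

logging.basicConfig(level=logging.WARNING)
log_startup_warning("frontend package is out of date")  # during startup
print_startup_warnings()                                # after startup completes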
comfy/cli_args.py

@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
comfy/clip_vision.py

@@ -9,6 +9,7 @@ import comfy.model_patcher
 import comfy.model_management
 import comfy.utils
 import comfy.clip_model
+import comfy.image_encoders.dino2
 
 class Output:
     def __getitem__(self, key):

@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
     image = torch.clip((255. * image), 0, 255).round() / 255.0
     return (image - mean.view([3,1,1])) / std.view([3,1,1])
 
+IMAGE_ENCODERS = {
+    "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+}
+
 class ClipVisionModel():
     def __init__(self, json_config):
         with open(json_config) as f:

@@ -42,10 +49,11 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
+        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
         self.model.eval()
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+    elif "embeddings.patch_embeddings.projection.weight" in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
     else:
         return None
 
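Note: the encoder class is now chosen through a small registry keyed by the config's model_type instead of being hard-coded to CLIPVisionModelProjection. A minimal sketch of that dispatch pattern (the classes below are stand-ins, not the real encoders):

```python
# Registry-based dispatch, mirroring IMAGE_ENCODERS above.
class CLIPVisionStandIn:
    def __init__(self, config):
        self.config = config

class Dinov2StandIn(CLIPVisionStandIn):
    pass

IMAGE_ENCODERS = {
    "clip_vision_model": CLIPVisionStandIn,
    "siglip_vision_model": CLIPVisionStandIn,
    "dinov2": Dinov2StandIn,
}

config = {"model_type": "dinov2", "hidden_size": 1536}
# falls back to the CLIP class when model_type is absent
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
model = model_class(config)
print(type(model).__name__)  # Dinov2StandIn
```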
comfy/comfy_types/node_typing.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 from typing import Literal, TypedDict
+from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum
 

@@ -26,6 +27,7 @@ class IO(StrEnum):
     BOOLEAN = "BOOLEAN"
     INT = "INT"
     FLOAT = "FLOAT"
+    COMBO = "COMBO"
     CONDITIONING = "CONDITIONING"
     SAMPLER = "SAMPLER"
     SIGMAS = "SIGMAS"

@@ -66,6 +68,7 @@ class IO(StrEnum):
         b = frozenset(value.split(","))
         return not (b.issubset(a) or a.issubset(b))
 
+
 class RemoteInputOptions(TypedDict):
     route: str
     """The route to the remote source."""

@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
     refresh: int
     """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
 
+
+class MultiSelectOptions(TypedDict):
+    placeholder: NotRequired[str]
+    """The placeholder text to display in the multi-select widget when no items are selected."""
+    chip: NotRequired[bool]
+    """Specifies whether to use chips instead of comma separated values for the multi-select widget."""
+
+
 class InputTypeOptions(TypedDict):
     """Provides type hinting for the return type of the INPUT_TYPES node function.
 

@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
     """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
     """
     remote: RemoteInputOptions
-    """Specifies the configuration for a remote input."""
+    """Specifies the configuration for a remote input.
+
+    Available after ComfyUI frontend v1.9.7
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
     control_after_generate: bool
     """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
+    options: NotRequired[list[str | int | float]]
+    """COMBO type only. Specifies the selectable options for the combo widget.
+
+    Prefer:
+    ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
+
+    Over:
+    [["Option 1", "Option 2", "Option 3"]]
+    """
+    multi_select: NotRequired[MultiSelectOptions]
+    """COMBO type only. Specifies the configuration for a multi-select widget.
+
+    Available after ComfyUI frontend v1.13.4
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
 
 
 class HiddenInputTypeDict(TypedDict):
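Note: the new `options` and `multi_select` keys document how COMBO inputs can now be declared. A hypothetical node definition exercising them, following the preferred form shown in the docstrings above (the node and field names are illustrative, not from this diff):

```python
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # preferred combo form: type string plus an options dict
                "mode": ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}],
                # multi-select combo (frontend v1.13.4+), rendered with chips
                "tags": ["COMBO", {
                    "options": ["a", "b", "c"],
                    "multi_select": {"placeholder": "pick tags", "chip": True},
                }],
            }
        }

print(ExampleNode.INPUT_TYPES()["required"]["mode"][0])  # COMBO
```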
comfy/image_encoders/dino2.py (new file, 141 lines)

@@ -0,0 +1,141 @@
+import torch
+from comfy.text_encoders.bert import BertAttention
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+
+class Dino2AttentionOutput(torch.nn.Module):
+    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
+
+    def forward(self, x):
+        return self.dense(x)
+
+
+class Dino2AttentionBlock(torch.nn.Module):
+    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
+        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
+
+    def forward(self, x, mask, optimized_attention):
+        return self.output(self.attention(x, mask, optimized_attention))
+
+
+class LayerScale(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+
+    def forward(self, x):
+        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+
+
+class SwiGLUFFN(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        in_features = out_features = dim
+        hidden_features = int(dim * 4)
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
+        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.weights_in(x)
+        x1, x2 = x.chunk(2, dim=-1)
+        x = torch.nn.functional.silu(x1) * x2
+        return self.weights_out(x)
+
+
+class Dino2Block(torch.nn.Module):
+    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
+        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
+        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
+        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+    def forward(self, x, optimized_attention):
+        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
+        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
+        return x
+
+
+class Dino2Encoder(torch.nn.Module):
+    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+        super().__init__()
+        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+
+    def forward(self, x, intermediate_output=None):
+        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
+
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = len(self.layer) + intermediate_output
+
+        intermediate = None
+        for i, l in enumerate(self.layer):
+            x = l(x, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        return x, intermediate
+
+
+class Dino2PatchEmbeddings(torch.nn.Module):
+    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.projection = operations.Conv2d(
+            in_channels=num_channels,
+            out_channels=dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=True,
+            dtype=dtype,
+            device=device
+        )
+
+    def forward(self, pixel_values):
+        return self.projection(pixel_values).flatten(2).transpose(1, 2)
+
+
+class Dino2Embeddings(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        patch_size = 14
+        image_size = 518
+
+        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
+        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
+        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
+        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
+
+    def forward(self, pixel_values):
+        x = self.patch_embeddings(pixel_values)
+        # TODO: mask_token?
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
+        return x
+
+
+class Dinov2Model(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        num_layers = config_dict["num_hidden_layers"]
+        dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        layer_norm_eps = config_dict["layer_norm_eps"]
+
+        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
+        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+        x = self.embeddings(pixel_values)
+        x, i = self.encoder(x, intermediate_output=intermediate_output)
+        x = self.layernorm(x)
+        pooled_output = x[:, 0, :]
+        return x, i, pooled_output, None
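Note: SwiGLUFFN expands to 4x the model width, shrinks by 2/3 (the usual SwiGLU compensation for the extra gate projection), and rounds up to a multiple of 8. A worked check for the giant config's hidden_size of 1536:

```python
import torch

dim = 1536
hidden_features = int(dim * 4)                                 # 6144
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8  # 6144 * 2/3 = 4096, already a multiple of 8
print(hidden_features)                                         # 4096

# weights_in emits both halves of the gate in one projection:
x = torch.randn(1, 4, 2 * hidden_features)   # stand-in for weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
gated = torch.nn.functional.silu(x1) * x2    # SwiGLU gating
print(gated.shape)                           # torch.Size([1, 4, 4096])
```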
comfy/image_encoders/dino2_giant.json (new file, 21 lines)

@@ -0,0 +1,21 @@
+{
+    "attention_probs_dropout_prob": 0.0,
+    "drop_path_rate": 0.0,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.0,
+    "hidden_size": 1536,
+    "image_size": 518,
+    "initializer_range": 0.02,
+    "layer_norm_eps": 1e-06,
+    "layerscale_value": 1.0,
+    "mlp_ratio": 4,
+    "model_type": "dinov2",
+    "num_attention_heads": 24,
+    "num_channels": 3,
+    "num_hidden_layers": 40,
+    "patch_size": 14,
+    "qkv_bias": true,
+    "use_swiglu_ffn": true,
+    "image_mean": [0.485, 0.456, 0.406],
+    "image_std": [0.229, 0.224, 0.225]
+}
comfy/k_diffusion/sampling.py

@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     if len(sigmas) <= 1:
         return x
 
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()

@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if solver_type not in {'heun', 'midpoint'}:
         raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
 
+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
 
     old_denoised = None

@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if len(sigmas) <= 1:
         return x
 
+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
 
     denoised_1, denoised_2 = None, None

@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
 def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)

@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)

@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     if len(sigmas) <= 1:
         return x
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
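Note: the reordering in these hunks fixes a real crash. The old code read `extra_args.get("seed", None)` before `extra_args` was defaulted, so calling these samplers with `extra_args=None` raised AttributeError. A minimal repro of the difference:

```python
# Reading extra_args before defaulting it crashes when the caller passes None.
def before(extra_args=None):
    seed = extra_args.get("seed", None)               # AttributeError on None
    extra_args = {} if extra_args is None else extra_args
    return seed

def after(extra_args=None):
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)               # safe
    return seed

print(after(None))   # None
try:
    before(None)
except AttributeError as e:
    print("before() fails:", e)
```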
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         x = x + d_bar * dt
         old_d = d
     return x
+
+
+@torch.no_grad()
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
+    """
+    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
+    """
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    def default_noise_scaler(sigma):
+        return sigma * ((sigma ** 0.3).exp() + 10.0)
+    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    num_integration_points = 200.0
+    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
+
+    old_denoised = None
+    old_denoised_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        stage_used = min(max_stage, i + 1)
+        if sigmas[i + 1] == 0:
+            x = denoised
+        elif stage_used == 1:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+        else:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+
+            dt = sigmas[i + 1] - sigmas[i]
+            sigma_step_size = -dt / num_integration_points
+            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
+            scaled_pos = noise_scaler(sigma_pos)
+
+            # Stage 2
+            s = torch.sum(1 / scaled_pos) * sigma_step_size
+            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
+            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+
+            if stage_used >= 3:
+                # Stage 3
+                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
+                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
+                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
+            old_denoised_d = denoised_d
+
+        if s_noise != 0 and sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+        old_denoised = denoised
+    return x
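Note: the heart of the new ER-SDE sampler is the noise scaler, which weights the first-order update `x = r * x + (1 - r) * denoised` through the per-step ratio r. A small runnable sketch of the default scaler from the diff:

```python
import torch

# ER-SDE default noise scaler: sigma * (exp(sigma^0.3) + 10).
def default_noise_scaler(sigma):
    return sigma * ((sigma ** 0.3).exp() + 10.0)

sigmas = torch.tensor([10.0, 1.0, 0.1])
print(default_noise_scaler(sigmas))

# r < 1 for a decreasing sigma schedule, so each step pulls x
# toward the model's denoised prediction.
r = default_noise_scaler(sigmas[1]) / default_noise_scaler(sigmas[0])
print(r)
```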
comfy/ldm/flux/layers.py

@@ -159,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

@@ -195,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
 
         # calculate the txt bloks
-        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
-        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
comfy/ldm/flux/math.py

@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q_shape = q.shape
     k_shape = k.shape
 
-    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
-    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
-    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+    if pe is not None:
+        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
+        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
+        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
 
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)

@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
 
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
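Note: two changes here. RoPE is now skipped entirely when `pe is None`, and q/k are cast to the positional embedding's dtype instead of being forced to float32. The math itself is unchanged: the last dimension is treated as (x0, x1) pairs and each pair is multiplied by a 2x2 rotation stored in `pe`/`freqs_cis`. A small sketch of that pair rotation, matching the indexing in `apply_rope`:

```python
import torch

# One rotation matrix for one pair; column 0 scales x0, column 1 scales x1.
theta = torch.tensor(3.141592653589793 / 2)  # 90 degrees
rot = torch.tensor([[theta.cos(), -theta.sin()],
                    [theta.sin(),  theta.cos()]])

x_pair = torch.tensor([1.0, 0.0])
# mirrors: xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
out = rot[..., 0] * x_pair[0] + rot[..., 1] * x_pair[1]
print(out)  # tensor([~0., 1.]): the pair (1, 0) rotated by 90 degrees
```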
comfy/ldm/flux/model.py

@@ -115,8 +115,11 @@ class Flux(nn.Module):
         vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
         txt = self.txt_in(txt)
 
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
+        if img_ids is not None:
+            ids = torch.cat((txt_ids, img_ids), dim=1)
+            pe = self.pe_embedder(ids)
+        else:
+            pe = None
 
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
comfy/ldm/hunyuan_video/model.py

@@ -244,9 +244,11 @@ class HunyuanVideo(nn.Module):
             vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+            modulation_dims_txt = [(0, None, 1)]
         else:
             vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
+            modulation_dims_txt = None
 
         if self.params.guidance_embed:
             if guidance is not None:

@@ -273,14 +275,14 @@ class HunyuanVideo(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                     return out
 
-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")

@@ -295,10 +297,10 @@ class HunyuanVideo(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                     return out
 
-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                 img = out["img"]
             else:
                 img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
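Note: each modulation_dims entry appears to be a (start, end, row) triple: tokens in the slice [start:end] use the given row of the stacked modulation vector `vec`. Splitting into `modulation_dims_img` and `modulation_dims_txt` lets the image tokens of the first frame use the token-replace vector while all text tokens keep the ordinary one. A sketch of how the ranges are derived (the latent size and patch_size values below are illustrative):

```python
latent_h, latent_w = 60, 104   # example latent spatial size
patch_size = (1, 2, 2)         # (t, h, w) patching, illustrative

# frame_tokens = (initial_shape[-1] // patch_size[-1]) * (initial_shape[-2] // patch_size[-2])
frame_tokens = (latent_w // patch_size[-1]) * (latent_h // patch_size[-2])
print(frame_tokens)            # 1560 tokens for the first frame

# img tokens [0, frame_tokens) use modulation row 0 (token-replace vec),
# the rest use row 1; txt tokens always use row 1.
modulation_dims_img = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
```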
comfy/ldm/modules/attention.py

@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
         logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         exit(-1)
 
+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init

@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out
 
 
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
+
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
+
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():

@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
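Note: wrapping flash_attn_func in torch.library.custom_op with a register_fake meta implementation is what lets torch.compile trace through the op; the except AttributeError branch covers older PyTorch builds where custom_op does not exist. A self-contained sketch of the pattern, with PyTorch SDPA standing in for flash_attn_func so it runs without the flash-attn package (requires a recent PyTorch with torch.library.custom_op):

```python
import torch

@torch.library.custom_op("example::sdpa_attn", mutates_args=())
def sdpa_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # real implementation, opaque to the compiler
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@sdpa_attn.register_fake
def _(q, k, v):
    # "fake" (meta) implementation: shapes/dtypes only, no compute,
    # which is all torch.compile needs to trace the op
    return q.new_empty(q.shape)

q = k = v = torch.randn(1, 4, 16, 8)   # (batch, heads, seq, dim_head)
print(sdpa_attn(q, k, v).shape)        # torch.Size([1, 4, 16, 8])
```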
comfy/ldm/wan/model.py

@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
         context,
         clip_fea=None,
         freqs=None,
+        transformer_options={},
     ):
         r"""
         Forward pass through the diffusion model

@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)
 
-        # arguments
-        kwargs = dict(
-            e=e0,
-            freqs=freqs,
-            context=context)
-
-        for block in self.blocks:
-            x = block(x, **kwargs)
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context)
 
         # head
         x = self.head(x, e)

@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x
 
-    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size

@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
 
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
 
     def unpatchify(self, x, grid_sizes):
         r"""
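Note: WanModel now routes every transformer block through the same patches_replace["dit"] mechanism used by the Flux and HunyuanVideo models above, so a patch keyed ("double_block", i) can wrap or replace an individual block. A hedged, self-contained sketch of the dispatch (the stand-in block below is a trivial callable, not the real module):

```python
# A patch receives the block inputs plus an "original_block" callback.
def my_patch(args, extra):
    out = extra["original_block"](args)   # run the real block
    out["img"] = out["img"] * 1.0         # post-process hook point (no-op here)
    return out

blocks_replace = {("double_block", 0): my_patch}

def run_block(i, block, x, context, e0, freqs):
    # mirrors the dispatch loop added to forward_orig
    if ("double_block", i) in blocks_replace:
        def block_wrap(args):
            return {"img": block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])}
        return blocks_replace[("double_block", i)](
            {"img": x, "txt": context, "vec": e0, "pe": freqs},
            {"original_block": block_wrap})["img"]
    return block(x, e=e0, freqs=freqs, context=context)

block = lambda x, context=None, e=None, freqs=None: x + 1  # stand-in block
print(run_block(0, block, 1.0, None, None, None))  # 2.0, via the patched path
```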
comfy/model_base.py

@@ -973,11 +973,11 @@ class WAN21(BaseModel):
         self.image_to_video = image_to_video
 
     def concat_cond(self, **kwargs):
-        if not self.image_to_video:
+        noise = kwargs.get("noise", None)
+        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
             return None
 
         image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
         device = kwargs["device"]
 
         if image is None:

@@ -987,6 +987,9 @@ class WAN21(BaseModel):
         image = self.process_latent_in(image)
         image = utils.resize_to_batch_size(image, noise.shape[0])
 
+        if not self.image_to_video:
+            return image
+
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
             mask = torch.zeros_like(noise)[:, :4]
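Note: the gate on concatenated conditioning is now structural rather than flag-based: if the patch embedding expects exactly as many input channels as the noise latent provides, there is nothing to concatenate. A small sketch of that check (the channel and embedding dimensions below are illustrative):

```python
import torch

patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2))  # weight: (out, in, ...)
noise = torch.zeros(1, 16, 4, 60, 104)                              # (b, c, t, h, w)

# mirrors: self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]
if patch_embedding.weight.shape[1] == noise.shape[1]:
    print("no concat cond needed")   # the model consumes the bare latent
else:
    print("model expects extra concatenated channels")
```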
```diff
@@ -186,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
     else:
         return mem_total
 
+def mac_version():
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 
 try:
     logging.info("pytorch version: {}".format(torch_version))
+    mac_ver = mac_version()
+    if mac_ver is not None:
+        logging.info("Mac Version {}".format(mac_ver))
 except:
     pass
 
@@ -921,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state
@@ -969,12 +981,6 @@ def pytorch_attention_flash_attention():
         return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     return False
 
-def mac_version():
-    try:
-        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
-    except:
-        return None
-
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
 
```
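These hunks (comfy/model_management.py, judging by the function names in the hunk headers) move `mac_version()` above the startup logging block so the macOS version is printed next to the PyTorch version, and add `flash_attention_enabled()` as another CLI-flag gate beside `sage_attention_enabled()`. A sketch of that flag-gate pattern; the dispatch function at the end is illustrative and is not ComfyUI's actual backend selection logic:

```python
import argparse

# Parse once, then expose cheap predicate functions that attention code can
# call without touching the CLI layer directly.
parser = argparse.ArgumentParser()
parser.add_argument("--use-sage-attention", action="store_true")
parser.add_argument("--use-flash-attention", action="store_true")
args = parser.parse_args(["--use-flash-attention"])  # simulated CLI input

def sage_attention_enabled():
    return args.use_sage_attention

def flash_attention_enabled():
    return args.use_flash_attention

def pick_attention_backend():
    # Priority ordering here is an assumption for the sketch.
    if sage_attention_enabled():
        return "sage"
    if flash_attention_enabled():
        return "flash"
    return "pytorch"

print(pick_attention_backend())  # -> "flash"
```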
```diff
@@ -747,6 +747,7 @@ class ModelPatcher:
 
     def partially_unload(self, device_to, memory_to_free=0):
         with self.use_ejected():
+            hooks_unpatched = False
             memory_freed = 0
             patch_counter = 0
             unload_list = self._load_list()
@@ -770,6 +771,10 @@ class ModelPatcher:
                         move_weight = False
                         break
 
+                if not hooks_unpatched:
+                    self.unpatch_hooks()
+                    hooks_unpatched = True
+
                 if bk.inplace_update:
                     comfy.utils.copy_to_param(self.model, key, bk.weight)
                 else:
@@ -1089,7 +1094,6 @@ class ModelPatcher:
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
-            self.unpatch_hooks()
             if hooks is not None:
                 model_sd_keys = list(self.model_state_dict().keys())
                 memory_counter = None
@@ -1100,12 +1104,16 @@ class ModelPatcher:
                 # if have cached weights for hooks, use it
                 cached_weights = self.cached_hook_patches.get(hooks, None)
                 if cached_weights is not None:
+                    model_sd_keys_set = set(model_sd_keys)
                     for key in cached_weights:
                         if key not in model_sd_keys:
                             logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                             continue
                         self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                        model_sd_keys_set.remove(key)
+                    self.unpatch_hooks(model_sd_keys_set)
                 else:
+                    self.unpatch_hooks()
                     relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                     original_weights = None
                     if len(relevant_patches) > 0:
@@ -1116,6 +1124,8 @@ class ModelPatcher:
                             continue
                         self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                          memory_counter=memory_counter)
+            else:
+                self.unpatch_hooks()
             self.current_hooks = hooks
 
     def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1172,17 +1182,23 @@ class ModelPatcher:
             del out_weight
         del weight
 
-    def unpatch_hooks(self) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
                 return
             keys = list(self.hook_backup.keys())
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+            if whitelist_keys_set:
+                for k in keys:
+                    if k in whitelist_keys_set:
+                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                        self.hook_backup.pop(k)
+            else:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
 
-            self.hook_backup.clear()
+                self.hook_backup.clear()
             self.current_hooks = None
 
     def clean_hooks(self):
         self.unpatch_hooks()
```
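The ModelPatcher changes (comfy/model_patcher.py) all serve one goal: when the active hooks change, avoid unpatching weights that are about to be re-patched from the hook cache. `patch_hooks()` now tracks which keys the cache covers and passes the leftover set to `unpatch_hooks()`, whose new whitelist parameter restores only those keys while keeping the backups of everything else intact. A standalone sketch of the whitelist bookkeeping, with plain dicts standing in for the model weights and hook backups (this is not the real ModelPatcher API):

```python
# Plain-dict stand-ins for model weights and the hook backup store.
model = {"a": "patched", "b": "patched", "c": "patched"}
hook_backup = {"a": "orig_a", "b": "orig_b", "c": "orig_c"}

def unpatch_hooks(whitelist_keys_set=None):
    keys = list(hook_backup.keys())
    if whitelist_keys_set:
        # Restore (and drop the backup for) only whitelisted keys; the rest
        # stay patched and keep their backups for a later full unpatch.
        for k in keys:
            if k in whitelist_keys_set:
                model[k] = hook_backup.pop(k)
    else:
        for k in keys:
            model[k] = hook_backup[k]
        hook_backup.clear()

unpatch_hooks({"b"})
print(model)        # {'a': 'patched', 'b': 'orig_b', 'c': 'patched'}
print(hook_backup)  # {'a': 'orig_a', 'c': 'orig_c'}
```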
```diff
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation"]
+                  "gradient_estimation", "er_sde"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
```
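One line in comfy/samplers.py registers the new `er_sde` sampler name. If the registry follows its usual convention, each entry in KSAMPLER_NAMES resolves to a `sample_<name>` function on the k-diffusion sampling module, so this list entry presupposes a matching `sample_er_sde()` implementation elsewhere in the branch; a toy illustration of that lookup, under that assumption:

```python
# Toy stand-in for the sampling module; the real lookup lives in
# comfy.samplers and comfy.k_diffusion.sampling.
class sampling_module:
    @staticmethod
    def sample_er_sde(model, x, sigmas, extra_args=None):
        return x  # placeholder body

def resolve_sampler(name):
    return getattr(sampling_module, "sample_{}".format(name))

print(resolve_sampler("er_sde").__name__)  # sample_er_sde
```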
comfy/sd.py (15 lines changed)

```diff
@@ -440,6 +440,10 @@ class VAE:
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
 
+    def throw_exception_if_invalid(self):
+        if self.first_stage_model is None:
+            raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
+
     def vae_encode_crop_pixels(self, pixels):
         downscale_ratio = self.spacial_compression_encode()
 
@@ -495,6 +499,7 @@ class VAE:
             return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
 
     def decode(self, samples_in):
+        self.throw_exception_if_invalid()
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -525,6 +530,7 @@ class VAE:
         return pixel_samples
 
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
         model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
@@ -553,6 +559,7 @@ class VAE:
         return output.movedim(1, -1)
 
     def encode(self, pixel_samples):
+        self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -585,6 +592,7 @@ class VAE:
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
-        return None
+        logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
+        diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
+        if diffusion_model is None:
+            return None
+        return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
```
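The comfy/sd.py changes pair a new guard, `VAE.throw_exception_if_invalid()`, with a new fallback in `load_state_dict_guess_config()`: a file that turns out not to be a full checkpoint is now loaded as a diffusion model only, with a placeholder `VAE(sd={})` whose `first_stage_model` is None. The guard makes that placeholder fail loudly at the first encode/decode instead of crashing opaquely deep inside sampling. A minimal sketch of the fail-at-use pattern (class and message are illustrative, not the real VAE class):

```python
# A VAE-like object whose operations funnel through a validity check, so a
# diffusion-model-only checkpoint loads fine until something actually tries
# to use its (missing) VAE.
class PlaceholderVAE:
    def __init__(self, first_stage_model=None):
        self.first_stage_model = first_stage_model

    def throw_exception_if_invalid(self):
        if self.first_stage_model is None:
            raise RuntimeError("VAE is invalid: None (checkpoint has no VAE)")

    def decode(self, samples):
        self.throw_exception_if_invalid()
        return samples  # real decoding elided

vae = PlaceholderVAE()
try:
    vae.decode(object())
except RuntimeError as e:
    print(e)  # fails at use time, with a clear message, not at load time
```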
```diff
@@ -19,8 +19,6 @@ class Load3D():
             "image": ("LOAD_3D", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}
 
     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -55,8 +53,6 @@ class Load3DAnimation():
             "image": ("LOAD_3D_ANIMATION", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}
 
     RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -82,8 +78,6 @@ class Preview3D():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}
 
     OUTPUT_NODE = True
@@ -102,8 +96,6 @@ class Preview3DAnimation():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}
 
     OUTPUT_NODE = True
```
```diff
@@ -99,12 +99,13 @@ class LTXVAddGuide:
                 "negative": ("CONDITIONING", ),
                 "vae": ("VAE",),
                 "latent": ("LATENT",),
-                "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
+                "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
                     "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                 "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
-                               "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
-                               "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
-                               "Negative values are counted from the end of the video."}),
+                               "tooltip": "Frame index to start the conditioning at. For single-frame images or "
+                               "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
+                               "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
+                               "the nearest multiple of 8. Negative values are counted from the end of the video."}),
                 "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
@@ -127,12 +128,13 @@ class LTXVAddGuide:
         t = vae.encode(encode_pixels)
         return encode_pixels, t
 
-    def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
+    def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
         time_scale_factor, _, _ = scale_factors
         _, num_keyframes = get_keyframe_idxs(cond)
         latent_count = latent_length - num_keyframes
-        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
-        frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
+        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
+        if guide_length > 1:
+            frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
 
         latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
 
@@ -191,7 +193,7 @@ class LTXVAddGuide:
         _, _, latent_length, latent_height, latent_width = latent_image.shape
         image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
 
-        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
+        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
 
         num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
```
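The LTXVAddGuide change (most likely comfy_extras/nodes_lt.py) threads the guide's frame count into `get_latent_index()` so that only multi-frame guides get snapped to the temporal grid; a single conditioning image may now sit at any frame index, as the updated tooltip says. A worked example of the new arithmetic, with a time scale factor of 8 matching the "8*n + 1" phrasing and an illustrative latent count:

```python
time_scale_factor = 8
latent_count = 13  # latents covering (13 - 1) * 8 + 1 = 97 pixel frames

def place(frame_idx, guide_length):
    if frame_idx < 0:  # negative indices count from the end of the video
        frame_idx = max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
    if guide_length > 1:  # only multi-frame guides must start on a multiple of 8
        frame_idx = frame_idx // time_scale_factor * time_scale_factor
    latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
    return frame_idx, latent_idx

print(place(13, guide_length=1))  # (13, 2): single image may sit mid-group
print(place(13, guide_length=9))  # (8, 1):  video guide snapped down to 8
print(place(-1, guide_length=1))  # (96, 12): last frame of the 97-frame video
```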
```diff
@@ -28,15 +28,12 @@ class SaveWEBM:
             }
 
     RETURN_TYPES = ()
-    FUNCTION = "save_images"
+    FUNCTION = "save_video"
 
     OUTPUT_NODE = True
-
-    CATEGORY = "image/video"
+    CATEGORY = "video"
 
     EXPERIMENTAL = True
 
-    def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
+    def save_video(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
         filename_prefix += self.prefix_append
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
@@ -71,7 +68,7 @@ class SaveWEBM:
             "type": self.type
         }]
 
-        return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
+        return {"ui": {"video": results}}
 
 
 NODE_CLASS_MAPPINGS = {
```
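SaveWEBM's rename from `save_images` to `save_video` comes with a changed UI payload: the saved file entry moves from an `images` list plus an `animated` flag to a dedicated `video` key, which the frontend presumably consumes with a proper video output widget (the branch name `video_outp` suggests this is the point of the PR). The shapes of the two payloads, with an illustrative file entry:

```python
# Illustrative file entry; real entries come from folder_paths bookkeeping.
result = {"filename": "ComfyUI_00001_.webm", "subfolder": "", "type": "output"}

payload_before = {"ui": {"images": [result], "animated": (True,)}}
payload_after = {"ui": {"video": [result]}}
print(payload_after)
```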
```diff
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.25"
+__version__ = "0.3.26"
```
```diff
@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
                 continue
             else:
                 try:
+                    # Unwraps values wrapped in __value__ key. This is used to pass
+                    # list widget value to execution, as by default list value is
+                    # reserved to represent the connection between nodes.
+                    if isinstance(val, dict) and "__value__" in val:
+                        val = val["__value__"]
+                        inputs[x] = val
+
                     if type_input == "INT":
                         val = int(val)
                         inputs[x] = val
```
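The validation hunk (execution.py) exists because of an encoding ambiguity in the prompt format: a plain list as an input value denotes a connection to another node's output, so a widget whose literal value is a list has to be wrapped, as the inserted comment explains. A sketch of the distinction, with the wrapper shape taken from the diff and the usual `[node_id, output_index]` link pair:

```python
def unwrap(val):
    # Mirrors the new validate_inputs() logic: only dicts carrying the
    # __value__ key are treated as wrapped widget values.
    if isinstance(val, dict) and "__value__" in val:
        return val["__value__"]
    return val

link = ["4", 0]                          # connection between nodes: left alone
widget = {"__value__": ["a", "b", "c"]}  # list widget value: unwrapped

print(unwrap(link))    # ['4', 0]
print(unwrap(widget))  # ['a', 'b', 'c']
```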
main.py (18 lines changed)

```diff
@@ -139,7 +139,7 @@ from server import BinaryEventTypes
 import nodes
 import comfy.model_management
 import comfyui_version
-import app.frontend_management
+import app.logger
 
 
 def cuda_malloc_warning():
@@ -293,28 +293,14 @@ def start_comfyui(asyncio_loop=None):
     return asyncio_loop, prompt_server, start_all
 
 
-def warn_frontend_version(frontend_version):
-    try:
-        required_frontend = (0,)
-        req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt')
-        with open(req_path, 'r') as f:
-            required_frontend = tuple(map(int, f.readline().split('=')[-1].split('.')))
-        if frontend_version < required_frontend:
-            logging.warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), app.frontend_management.frontend_install_warning_message()))
-    except:
-        pass
-
-
 if __name__ == "__main__":
     # Running directly, just start ComfyUI.
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
-    frontend_version = app.frontend_management.frontend_version
-    logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, frontend_version))))
 
     event_loop, _, start_all_func = start_comfyui()
     try:
         x = start_all_func()
-        warn_frontend_version(frontend_version)
+        app.logger.print_startup_warnings()
         event_loop.run_until_complete(x)
     except KeyboardInterrupt:
         logging.info("\nStopped server")
```
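main.py stops checking the frontend version itself: the import swap and the removal of `warn_frontend_version()` hand that duty to `app.logger.print_startup_warnings()`, so warnings collected while modules import are flushed at one well-defined point after startup. A sketch of that deferred-warning pattern; the buffer and the registration function here are assumptions, only the flush call's name comes from the diff:

```python
import logging

_startup_warnings = []  # assumed module-level buffer

def log_startup_warning(msg):
    # Assumed registration hook: modules call this as they load.
    _startup_warnings.append(msg)

def print_startup_warnings():
    # Name matches the call in the diff; body is a plausible sketch.
    for msg in _startup_warnings:
        logging.warning(msg)
    _startup_warnings.clear()

logging.basicConfig(level=logging.INFO)
log_startup_warning("Installed frontend version is lower than recommended.")
print_startup_warnings()  # emitted once, after the server has started
```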
nodes.py (12 lines changed)

```diff
@@ -489,7 +489,7 @@ class SaveLatent:
         file = os.path.join(full_output_folder, file)
 
         output = {}
-        output["latent_tensor"] = samples["samples"]
+        output["latent_tensor"] = samples["samples"].contiguous()
         output["latent_format_version_0"] = torch.tensor([])
 
         comfy.utils.save_torch_file(output, file, metadata=metadata)
@@ -770,6 +770,7 @@ class VAELoader:
         vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
         sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
+        vae.throw_exception_if_invalid()
         return (vae,)
 
 class ControlNetLoader:
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
 
     DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
     EXPERIMENTAL = True
-    FUNCTION = "load_image_output"
-
-    def load_image_output(self, image):
-        return self.load_image(f"{image} [output]")
-
-    @classmethod
-    def VALIDATE_INPUTS(s, image):
-        return True
+    FUNCTION = "load_image"
 
 
 class ImageScale:
```
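Three small node fixes here: SaveLatent forces the latent tensor contiguous before serialization, VAELoader validates the loaded VAE immediately (surfacing the new `throw_exception_if_invalid()` error at load time rather than first use), and LoadImageOutput drops its wrapper in favor of the inherited `load_image`. The `.contiguous()` call matters because safetensors refuses to save non-contiguous tensors, which latents can easily become after permutes or slicing:

```python
import torch

# A latent that went through a permute is no longer contiguous in memory,
# and serializing it to safetensors would fail without .contiguous().
latent = torch.zeros(1, 4, 64, 64).permute(0, 1, 3, 2)
print(latent.is_contiguous())               # False
print(latent.contiguous().is_contiguous())  # True
```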
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.25"
+version = "0.3.26"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
```
```diff
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.11.8
+comfyui-frontend-package==1.12.14
 torch
 torchsde
 torchvision
```
```diff
@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
 def test_init_frontend_default():
     version_string = DEFAULT_VERSION_STRING
     frontend_path = FrontendManager.init_frontend(version_string)
-    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+    assert frontend_path == FrontendManager.default_frontend_path()
 
 
 def test_init_frontend_invalid_version():
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)
 
 
 @pytest.fixture
 def mock_os_functions():
-    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
-         patch('app.frontend_management.os.listdir') as mock_listdir, \
-         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+    with (
+        patch("app.frontend_management.os.makedirs") as mock_makedirs,
+        patch("app.frontend_management.os.listdir") as mock_listdir,
+        patch("app.frontend_management.os.rmdir") as mock_rmdir,
+    ):
         mock_listdir.return_value = []  # Simulate empty directory
         yield mock_makedirs, mock_listdir, mock_rmdir
 
 
 @pytest.fixture
 def mock_download():
-    with patch('app.frontend_management.download_release_asset_zip') as mock:
+    with patch("app.frontend_management.download_release_asset_zip") as mock:
         mock.side_effect = Exception("Download failed")  # Simulate download failure
         yield mock
 
 
 def test_finally_block(mock_os_functions, mock_download, mock_provider):
     # Arrange
     mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
-    version_string = 'test-owner/test-repo@1.0.0'
+    version_string = "test-owner/test-repo@1.0.0"
 
     # Act & Assert
     with pytest.raises(Exception):
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
     version_string = "invalid"
     with pytest.raises(argparse.ArgumentTypeError):
         FrontendManager.parse_version_string(version_string)
+
+
+def test_init_frontend_default_with_mocks():
+    # Arrange
+    version_string = DEFAULT_VERSION_STRING
+
+    # Act
+    with (
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/mocked/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/mocked/path"
+    mock_check.assert_called_once()
+
+
+def test_init_frontend_fallback_on_error():
+    # Arrange
+    version_string = "test-owner/test-repo@1.0.0"
+
+    # Act
+    with (
+        patch.object(
+            FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
+        ),
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/default/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/default/path"
+    mock_check.assert_called_once()
```
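The test updates track the `DEFAULT_FRONTEND_PATH` constant becoming a `default_frontend_path()` method and add coverage for the fallback path when `init_frontend_unsafe` raises. The fixtures are also rewritten from backslash continuations to the parenthesized with-statement. One caveat worth noting: parenthesized context managers only became official grammar in Python 3.10, though CPython 3.9's parser happens to accept them, so the project's `>=3.9` floor still holds in practice. The two equivalent spellings:

```python
from contextlib import nullcontext

# Backslash-continuation form (works on any supported Python):
with nullcontext() as a, \
     nullcontext() as b:
    pass

# Parenthesized form used by the updated fixtures (official since 3.10):
with (
    nullcontext() as a,
    nullcontext() as b,
):
    pass
```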