Compare commits
yo-add-pre ... video_outp
85 Commits

SHA1: 60b459bb4c, 3b19fc76e3, 50614f1b79, 6dc7b0bfe3, e8e990d6b8, 2e24a15905, fd5297131f, 55a1b09ddc, 3c3988df45, 7ebd8087ff, c624c29d66, a2448fc527, 6a0daa79b6, 9c98c6358b, 7aceb9f91c, 35504e2f93, 299436cfed, 52e566d2bc, 9b6cd9b874, 3fc688aebd, f4411250f3, d2a0fb6bb0, 01015bff16, 2330754b0e, bc219a6487, 94689766ad, cfbe4b49ca, ca8efab79f, 65ea778a5e, db9f2a34fc, 7946049794, 6f6349b6a7, 1f138dd382, b779349b55, 35e2dcf5d7, 67c7184b74, 6f8e766509, e1da98a14a, a73410aafa, 9aac21f894, 528d1b3563, 2bc4b5968f, 7395b0c0d1, 0952569493, 29832b3b61, be4e760648, c3d9cc4592, 84cc9cb528, ebbb920163, d60fe0af4a, 5dbd250965, 4ab1875283, 11b1f27cb1, 70e15fd743, e1474150de, e62d72e8ca, 1650cda030, a13125840c, dfa36e6855, 0124be4d93, 29a70ca101, 0bef826a98, 85ef295069, 5d84607bf3, c1909f350f, 52b3469606, 889519971f, 76739c23c3, a80bc822a2, 872780d236, 6d45ffbe23, 77633ba77d, 30e6cfb1a0, dc134b2fdb, 369b079ff6, 9c9a7f012a, 93fedd92fe, 745b13649b, 2b140654c7, 65042f7d39, 7c7c70c400, 8362199ee7, f86c724ef2, d6e5d487ad, 6752a826f6
@@ -0,0 +1,2 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+pause
.github/workflows/stable-release.yml (vendored, 2 changes)
@@ -22,7 +22,7 @@ on:
 description: 'Python patch version'
 required: true
 type: string
-default: "8"
+default: "9"


 jobs:
@@ -29,7 +29,7 @@ on:
 description: 'python patch version'
 required: true
 type: string
-default: "8"
+default: "9"
 # push:
 # branches:
 # - master
@@ -7,7 +7,7 @@ on:
 description: 'cuda version'
 required: true
 type: string
-default: "126"
+default: "128"

 python_minor:
 description: 'python minor version'
@@ -19,7 +19,7 @@ on:
 description: 'python patch version'
 required: true
 type: string
-default: "1"
+default: "2"
 # push:
 # branches:
 # - master
@@ -34,7 +34,7 @@ jobs:
 steps:
 - uses: actions/checkout@v4
 with:
-fetch-depth: 0
+fetch-depth: 30
 persist-credentials: false
 - uses: actions/setup-python@v5
 with:
@@ -74,7 +74,7 @@ jobs:
 pause" > ./update/update_comfyui_and_python_dependencies.bat
 cd ..

-"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
+"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
 mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z

 cd ComfyUI_windows_portable_nightly_pytorch
@@ -19,7 +19,7 @@ on:
 description: 'python patch version'
 required: true
 type: string
-default: "8"
+default: "9"
 # push:
 # branches:
 # - master
@@ -19,5 +19,6 @@
 /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
 /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

-# Extra nodes
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
+# Node developers
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:

 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```

-This is the command to install pytorch nightly instead which might have performance improvements:
+This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```

 #### Troubleshooting
@@ -3,6 +3,7 @@ import argparse
 import logging
 import os
 import re
+import sys
 import tempfile
 import zipfile
 import importlib
@@ -10,19 +11,43 @@ from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import TypedDict, Optional
+from importlib.metadata import version

 import requests
 from typing_extensions import NotRequired

 from comfy.cli_args import DEFAULT_VERSION_STRING
+import app.logger
+
+# The path to the requirements.txt file
+req_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def frontend_install_warning_message():
+    """The warning message to display when the frontend version is not up to date."""
+    extra = ""
+    if sys.flags.no_user_site:
+        extra = "-s "
+    return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
+
+
-try:
-    import comfyui_frontend_package
-except ImportError as e:
-    # TODO: Remove the check after roll out of 0.3.16
-    logging.error("comfyui-frontend-package is not installed. Please install the updated requirements.txt file by running: pip install -r requirements.txt")
-    raise e
+def check_frontend_version():
+    """Check if the frontend version is up to date."""
+
+    def parse_version(version: str) -> tuple[int, int, int]:
+        return tuple(map(int, version.split(".")))
+
+    try:
+        frontend_version_str = version("comfyui-frontend-package")
+        frontend_version = parse_version(frontend_version_str)
+        with open(req_path, "r", encoding="utf-8") as f:
+            required_frontend = parse_version(f.readline().split("=")[-1])
+        if frontend_version < required_frontend:
+            app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
+        else:
+            logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
+    except Exception as e:
+        logging.error(f"Failed to check frontend version: {e}")
+

 REQUEST_TIMEOUT = 10  # seconds
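The version gate introduced above reduces to comparing dotted version strings as integer tuples; a minimal standalone sketch, assuming the first line of requirements.txt looks like `comfyui-frontend-package==1.10.17` (the exact version here is illustrative):

```
# Minimal sketch of the version comparison used above; the requirements line is an assumed example.
def parse_version(version: str) -> tuple[int, int, int]:
    return tuple(map(int, version.split(".")))

installed = parse_version("1.9.7")
required = parse_version("comfyui-frontend-package==1.10.17".split("=")[-1])
if installed < required:
    print(f"installed frontend {installed} is older than required {required}")
```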
@@ -119,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:


 class FrontendManager:
-    DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

+    @classmethod
+    def default_frontend_path(cls) -> str:
+        try:
+            import comfyui_frontend_package
+            return str(importlib.resources.files(comfyui_frontend_package) / "static")
+        except ImportError:
+            logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
+            sys.exit(-1)
+
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """
@@ -158,7 +191,8 @@ class FrontendManager:
         main error source might be request timeout or invalid URL.
         """
         if version_string == DEFAULT_VERSION_STRING:
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()

         repo_owner, repo_name, version = cls.parse_version_string(version_string)
@@ -211,4 +245,5 @@ class FrontendManager:
         except Exception as e:
             logging.error("Failed to initialize frontend: %s", e)
             logging.info("Falling back to the default frontend.")
-            return cls.DEFAULT_FRONTEND_PATH
+            check_frontend_version()
+            return cls.default_frontend_path()
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
         logger.addHandler(stdout_handler)

     logger.addHandler(stream_handler)
+
+
+STARTUP_WARNINGS = []
+
+
+def log_startup_warning(msg):
+    logging.warning(msg)
+    STARTUP_WARNINGS.append(msg)
+
+
+def print_startup_warnings():
+    for s in STARTUP_WARNINGS:
+        logging.warning(s)
+    STARTUP_WARNINGS.clear()
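The new helpers buffer warnings raised during startup so they can be replayed after the initial log output has scrolled past; a small self-contained usage sketch mirroring the functions above:

```
import logging
logging.basicConfig(level=logging.INFO)

STARTUP_WARNINGS = []

def log_startup_warning(msg):
    # Log immediately and remember the message for later replay.
    logging.warning(msg)
    STARTUP_WARNINGS.append(msg)

def print_startup_warnings():
    # Re-emit everything collected during startup, then clear the buffer.
    for s in STARTUP_WARNINGS:
        logging.warning(s)
    STARTUP_WARNINGS.clear()

log_startup_warning("Installed frontend version is lower than the recommended version.")
print_startup_warnings()
```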
@@ -1,7 +1,6 @@
 import argparse
 import enum
 import os
-from typing import Optional
 import comfy.options

@@ -107,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@@ -166,13 +166,14 @@ parser.add_argument(
     """,
 )

-def is_valid_directory(path: Optional[str]) -> Optional[str]:
-    """Validate if the given path is a directory."""
-    if path is None:
-        return None
+def is_valid_directory(path: str) -> str:
+    """Validate if the given path is a directory, and check permissions."""
+    if not os.path.exists(path):
+        raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")

     if not os.path.isdir(path):
-        raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+        raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
+    if not os.access(path, os.R_OK):
+        raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
     return path

 parser.add_argument(
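The reworked validator is intended to be passed as an argparse `type=` callable so that bad paths fail at parse time; a short sketch (the `--example-dir` flag name is hypothetical, not taken from this hunk):

```
import argparse
import os

def is_valid_directory(path: str) -> str:
    """Validate if the given path is a directory, and check permissions."""
    if not os.path.exists(path):
        raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
    if not os.access(path, os.R_OK):
        raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
    return path

parser = argparse.ArgumentParser()
parser.add_argument("--example-dir", type=is_valid_directory)  # hypothetical flag
print(parser.parse_args(["--example-dir", "."]).example_dir)
```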
@@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module):
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
-        x = self.embeddings(input_tokens, dtype=dtype)
+    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+        if embeds is not None:
+            x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+        else:
+            x = self.embeddings(input_tokens, dtype=dtype)
+
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -116,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
         if i is not None and final_layer_norm_intermediate:
             i = self.final_layer_norm(i)

-        pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+        if num_tokens is not None:
+            pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
+        else:
+            pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
         return x, i, pooled_output

 class CLIPTextModel(torch.nn.Module):
@@ -204,6 +211,15 @@ class CLIPVision(torch.nn.Module):
         pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output

+class LlavaProjector(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, dtype, device, operations):
+        super().__init__()
+        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -213,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
         else:
             self.visual_projection = lambda a: a

+        if "llava3" == config_dict.get("projector_type", None):
+            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+        else:
+            self.multi_modal_projector = None
+
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
         out = self.visual_projection(x[2])
-        return (x[0], x[1], out)
+        projected = None
+        if self.multi_modal_projector is not None:
+            projected = self.multi_modal_projector(x[1])
+
+        return (x[0], x[1], out, projected)
@@ -9,6 +9,7 @@ import comfy.model_patcher
 import comfy.model_management
 import comfy.utils
 import comfy.clip_model
+import comfy.image_encoders.dino2

 class Output:
     def __getitem__(self, key):
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
     image = torch.clip((255. * image), 0, 255).round() / 255.0
     return (image - mean.view([3,1,1])) / std.view([3,1,1])

+IMAGE_ENCODERS = {
+    "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+}
+
 class ClipVisionModel():
     def __init__(self, json_config):
         with open(json_config) as f:
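The `IMAGE_ENCODERS` table added above turns the config's `model_type` into a class lookup, falling back to the CLIP vision model; a minimal sketch of that dispatch with stand-in classes:

```
# Stand-in classes; the real ones live in comfy.clip_model and comfy.image_encoders.dino2.
class CLIPVisionModelProjection: ...
class Dinov2Model: ...

IMAGE_ENCODERS = {
    "clip_vision_model": CLIPVisionModelProjection,
    "siglip_vision_model": CLIPVisionModelProjection,
    "dinov2": Dinov2Model,
}

config = {"model_type": "dinov2", "hidden_size": 1536}
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
print(model_class.__name__)  # Dinov2Model
```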
@@ -42,10 +49,11 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
+        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
         self.model.eval()

         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -65,6 +73,7 @@ class ClipVisionModel():
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        outputs["mm_projected"] = out[3]
         return outputs

 def convert_to_transformers(sd, prefix):
@@ -104,9 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
         elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
-            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+            if "multi_modal_projector.linear_1.bias" in sd:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+            else:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+    elif "embeddings.patch_embeddings.projection.weight" in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
     else:
         return None
comfy/clip_vision_config_vitl_336_llava.json (new file, 19 lines)
@@ -0,0 +1,19 @@
+{
+    "attention_dropout": 0.0,
+    "dropout": 0.0,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "image_size": 336,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "layer_norm_eps": 1e-5,
+    "model_type": "clip_vision_model",
+    "num_attention_heads": 16,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "patch_size": 14,
+    "projection_dim": 768,
+    "projector_type": "llava3",
+    "torch_dtype": "float32"
+}
@@ -1,6 +1,6 @@
 import torch
 from typing import Callable, Protocol, TypedDict, Optional, List
-from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
+from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator


 class UnetApplyFunction(Protocol):
@@ -42,4 +42,5 @@ __all__ = [
     InputTypeDict.__name__,
     ComfyNodeABC.__name__,
     CheckLazyMixin.__name__,
+    FileLocator.__name__,
 ]
@@ -2,6 +2,7 @@

 from __future__ import annotations
 from typing import Literal, TypedDict
+from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -26,6 +27,7 @@ class IO(StrEnum):
     BOOLEAN = "BOOLEAN"
     INT = "INT"
     FLOAT = "FLOAT"
+    COMBO = "COMBO"
     CONDITIONING = "CONDITIONING"
     SAMPLER = "SAMPLER"
     SIGMAS = "SIGMAS"
@@ -66,6 +68,7 @@ class IO(StrEnum):
         b = frozenset(value.split(","))
         return not (b.issubset(a) or a.issubset(b))

+
 class RemoteInputOptions(TypedDict):
     route: str
     """The route to the remote source."""
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
     refresh: int
     """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""


+class MultiSelectOptions(TypedDict):
+    placeholder: NotRequired[str]
+    """The placeholder text to display in the multi-select widget when no items are selected."""
+    chip: NotRequired[bool]
+    """Specifies whether to use chips instead of comma separated values for the multi-select widget."""
+
+
 class InputTypeOptions(TypedDict):
     """Provides type hinting for the return type of the INPUT_TYPES node function.
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
     # default: bool
     label_on: str
     """The label to use in the UI when the bool is True (``BOOLEAN``)"""
-    label_on: str
+    label_off: str
     """The label to use in the UI when the bool is False (``BOOLEAN``)"""
     # class InputTypeString(InputTypeOptions):
     #     default: str
@@ -133,7 +144,22 @@ class InputTypeOptions(TypedDict):
     """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
     """
     remote: RemoteInputOptions
-    """Specifies the configuration for a remote input."""
+    """Specifies the configuration for a remote input.
+    Available after ComfyUI frontend v1.9.7
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
+    control_after_generate: bool
+    """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
+    options: NotRequired[list[str | int | float]]
+    """COMBO type only. Specifies the selectable options for the combo widget.
+    Prefer:
+    ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
+    Over:
+    [["Option 1", "Option 2", "Option 3"]]
+    """
+    multi_select: NotRequired[MultiSelectOptions]
+    """COMBO type only. Specifies the configuration for a multi-select widget.
+    Available after ComfyUI frontend v1.13.4
+    https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""


 class HiddenInputTypeDict(TypedDict):
@@ -293,3 +319,14 @@ class CheckLazyMixin:

         need = [name for name in kwargs if kwargs[name] is None]
         return need
+
+
+class FileLocator(TypedDict):
+    """Provides type hinting for the file location"""
+
+    filename: str
+    """The filename of the file."""
+    subfolder: str
+    """The subfolder of the file."""
+    type: Literal["input", "output", "temp"]
+    """The root folder of the file."""
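A hedged example of how a custom node might declare inputs using the COMBO `options` and `multi_select` fields documented in the hunks above (the node class itself is hypothetical):

```
# Hypothetical node showing the documented COMBO input forms.
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Preferred form: explicit COMBO type with an "options" list.
                "mode": ("COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}),
                # Multi-select combo (frontend v1.13.4+), rendered as chips instead of comma-separated values.
                "tags": ("COMBO", {
                    "options": ["a", "b", "c"],
                    "multi_select": {"placeholder": "pick tags", "chip": True},
                }),
            }
        }
```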
comfy/image_encoders/dino2.py (new file, 141 lines)
@@ -0,0 +1,141 @@
+import torch
+from comfy.text_encoders.bert import BertAttention
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+
+class Dino2AttentionOutput(torch.nn.Module):
+    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
+
+    def forward(self, x):
+        return self.dense(x)
+
+
+class Dino2AttentionBlock(torch.nn.Module):
+    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
+        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
+
+    def forward(self, x, mask, optimized_attention):
+        return self.output(self.attention(x, mask, optimized_attention))
+
+
+class LayerScale(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+
+    def forward(self, x):
+        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+
+
+class SwiGLUFFN(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        in_features = out_features = dim
+        hidden_features = int(dim * 4)
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
+        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.weights_in(x)
+        x1, x2 = x.chunk(2, dim=-1)
+        x = torch.nn.functional.silu(x1) * x2
+        return self.weights_out(x)
+
+
+class Dino2Block(torch.nn.Module):
+    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+        super().__init__()
+        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
+        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
+        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
+        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+    def forward(self, x, optimized_attention):
+        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
+        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
+        return x
+
+
+class Dino2Encoder(torch.nn.Module):
+    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+        super().__init__()
+        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+
+    def forward(self, x, intermediate_output=None):
+        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
+
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = len(self.layer) + intermediate_output
+
+        intermediate = None
+        for i, l in enumerate(self.layer):
+            x = l(x, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        return x, intermediate
+
+
+class Dino2PatchEmbeddings(torch.nn.Module):
+    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.projection = operations.Conv2d(
+            in_channels=num_channels,
+            out_channels=dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=True,
+            dtype=dtype,
+            device=device
+        )
+
+    def forward(self, pixel_values):
+        return self.projection(pixel_values).flatten(2).transpose(1, 2)
+
+
+class Dino2Embeddings(torch.nn.Module):
+    def __init__(self, dim, dtype, device, operations):
+        super().__init__()
+        patch_size = 14
+        image_size = 518
+
+        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
+        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
+        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
+        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
+
+    def forward(self, pixel_values):
+        x = self.patch_embeddings(pixel_values)
+        # TODO: mask_token?
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
+        return x
+
+
+class Dinov2Model(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        num_layers = config_dict["num_hidden_layers"]
+        dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        layer_norm_eps = config_dict["layer_norm_eps"]
+
+        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
+        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+        x = self.embeddings(pixel_values)
+        x, i = self.encoder(x, intermediate_output=intermediate_output)
+        x = self.layernorm(x)
+        pooled_output = x[:, 0, :]
+        return x, i, pooled_output, None
comfy/image_encoders/dino2_giant.json (new file, 21 lines)
@@ -0,0 +1,21 @@
+{
+    "attention_probs_dropout_prob": 0.0,
+    "drop_path_rate": 0.0,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.0,
+    "hidden_size": 1536,
+    "image_size": 518,
+    "initializer_range": 0.02,
+    "layer_norm_eps": 1e-06,
+    "layerscale_value": 1.0,
+    "mlp_ratio": 4,
+    "model_type": "dinov2",
+    "num_attention_heads": 24,
+    "num_channels": 3,
+    "num_hidden_layers": 40,
+    "patch_size": 14,
+    "qkv_bias": true,
+    "use_swiglu_ffn": true,
+    "image_mean": [0.485, 0.456, 0.406],
+    "image_std": [0.229, 0.224, 0.225]
+}
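Tying the two new files together: `load_clipvision_from_sd` selects `dino2_giant.json` for DINOv2 checkpoints, and `ClipVisionModel.__init__` instantiates the mapped class as shown in the earlier hunk. A hedged sketch, assuming the ComfyUI repo root is the working directory; the layer count is shrunk so the example is cheap to run, and since the weights are uninitialized only the output shape is meaningful:

```
import copy
import json

import torch

import comfy.ops
from comfy.image_encoders.dino2 import Dinov2Model

with open("comfy/image_encoders/dino2_giant.json") as f:
    config = json.load(f)

small = copy.deepcopy(config)
small["num_hidden_layers"] = 2  # shrink for the sketch; the real loader uses the config as-is

model = Dinov2Model(small, torch.float32, "cpu", comfy.ops.manual_cast)
x, intermediate, pooled, _ = model(torch.zeros(1, 3, 518, 518))
print(pooled.shape)  # expected: torch.Size([1, 1536])
```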
@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     if len(sigmas) <= 1:
         return x

+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if solver_type not in {'heun', 'midpoint'}:
         raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])

     old_denoised = None
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if len(sigmas) <= 1:
         return x

+    extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
-    extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])

     denoised_1, denoised_2 = None, None
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
 def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     if len(sigmas) <= 1:
         return x
-
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
         return x
-
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
 def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     if len(sigmas) <= 1:
         return x
-
+    extra_args = {} if extra_args is None else extra_args
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         x = x + d_bar * dt
         old_d = d
     return x
+
+
+@torch.no_grad()
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
+    """
+    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
+    """
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    def default_noise_scaler(sigma):
+        return sigma * ((sigma ** 0.3).exp() + 10.0)
+    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    num_integration_points = 200.0
+    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
+
+    old_denoised = None
+    old_denoised_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        stage_used = min(max_stage, i + 1)
+        if sigmas[i + 1] == 0:
+            x = denoised
+        elif stage_used == 1:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+        else:
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+
+            dt = sigmas[i + 1] - sigmas[i]
+            sigma_step_size = -dt / num_integration_points
+            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
+            scaled_pos = noise_scaler(sigma_pos)
+
+            # Stage 2
+            s = torch.sum(1 / scaled_pos) * sigma_step_size
+            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
+            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+
+            if stage_used >= 3:
+                # Stage 3
+                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
+                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
+                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
+            old_denoised_d = denoised_d
+
+        if s_noise != 0 and sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+        old_denoised = denoised
+    return x
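The first-order branch of `sample_er_sde` is an interpolation between the current latent and the denoised prediction, with the ratio set by the noise scaler; a self-contained sketch of that single step:

```
import torch

def default_noise_scaler(sigma):
    return sigma * ((sigma ** 0.3).exp() + 10.0)

x = torch.randn(1, 4, 8, 8)        # current latent
denoised = torch.zeros_like(x)      # model prediction at sigma_i (placeholder)
sigma_i, sigma_next = torch.tensor(10.0), torch.tensor(7.0)

r = default_noise_scaler(sigma_next) / default_noise_scaler(sigma_i)
x = r * x + (1 - r) * denoised      # stage-1 ER-SDE update toward the denoised estimate
print(r.item())
```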
@@ -19,6 +19,10 @@
 import torch
 from torch import nn
 from torch.autograd import Function
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
+

 class vector_quantize(Function):
     @staticmethod
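The `ops = comfy.ops.disable_weight_init` alias used throughout the following hunks is, as a hedged sketch of the idea (the real implementation in comfy/ops.py may differ in details), a set of layer subclasses whose `reset_parameters` is a no-op, so layers are created without random initialization and weights come straight from the checkpoint:

```
# Illustrative sketch of a "disable weight init" wrapper; not the actual comfy.ops code.
import torch.nn as nn

class disable_weight_init:
    class Conv2d(nn.Conv2d):
        def reset_parameters(self):
            return None  # skip random init; weights are loaded from the checkpoint

    class Linear(nn.Linear):
        def reset_parameters(self):
            return None

ops = disable_weight_init
layer = ops.Conv2d(3, 8, kernel_size=3)  # allocated but not randomly initialized
```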
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
         self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.depthwise = nn.Sequential(
             nn.ReplicationPad2d(1),
-            nn.Conv2d(c, c, kernel_size=3, groups=c)
+            ops.Conv2d(c, c, kernel_size=3, groups=c)
         )

         # channelwise
         self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c_hidden),
+            ops.Linear(c, c_hidden),
             nn.GELU(),
-            nn.Linear(c_hidden, c),
+            ops.Linear(c_hidden, c),
         )

         self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@@ -171,16 +175,16 @@ class StageA(nn.Module):
         # Encoder blocks
         self.in_block = nn.Sequential(
             nn.PixelUnshuffle(2),
-            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+            ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
         )
         down_blocks = []
         for i in range(levels):
             if i > 0:
-                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+                down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
             block = ResBlock(c_levels[i], c_levels[i] * 4)
             down_blocks.append(block)
         down_blocks.append(nn.Sequential(
-            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
         ))
         self.down_blocks = nn.Sequential(*down_blocks)
@@ -191,7 +195,7 @@ class StageA(nn.Module):

         # Decoder blocks
         up_blocks = [nn.Sequential(
-            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+            ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
         )]
         for i in range(levels):
             for j in range(bottleneck_blocks if i == 0 else 1):
@@ -199,11 +203,11 @@ class StageA(nn.Module):
                 up_blocks.append(block)
             if i < levels - 1:
                 up_blocks.append(
-                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+                    ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                        padding=1))
         self.up_blocks = nn.Sequential(*up_blocks)
         self.out_block = nn.Sequential(
-            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+            ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
             nn.PixelShuffle(2),
         )
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
         super().__init__()
         d = max(depth - 3, 3)
         layers = [
-            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+            nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
             nn.LeakyReLU(0.2),
         ]
         for i in range(depth - 1):
             c_in = c_hidden // (2 ** max((d - i), 0))
             c_out = c_hidden // (2 ** max((d - 1 - i), 0))
-            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+            layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
             layers.append(nn.InstanceNorm2d(c_out))
             layers.append(nn.LeakyReLU(0.2))
         self.encoder = nn.Sequential(*layers)
-        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+        self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
         self.logits = nn.Sigmoid()

     def forward(self, x, cond=None):
@@ -19,6 +19,9 @@ import torch
 import torchvision
 from torch import nn

+import comfy.ops
+
+ops = comfy.ops.disable_weight_init

 # EfficientNet
 class EfficientNetEncoder(nn.Module):
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
         super().__init__()
         self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
         self.mapper = nn.Sequential(
-            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
         )
         self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):

     def forward(self, x):
         x = x * 0.5 + 0.5
-        x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+        x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
         o = self.mapper(self.backbone(x))
         return o
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
|||||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.blocks = nn.Sequential(
|
self.blocks = nn.Sequential(
|
||||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden),
|
nn.BatchNorm2d(c_hidden),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden),
|
nn.BatchNorm2d(c_hidden),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 2),
|
nn.BatchNorm2d(c_hidden // 2),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 2),
|
nn.BatchNorm2d(c_hidden // 2),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.BatchNorm2d(c_hidden // 4),
|
nn.BatchNorm2d(c_hidden // 4),
|
||||||
|
|
||||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
 
         return (
             ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
         )
 
 
+def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
+    if modulation_dims is None:
+        if m_add is not None:
+            return tensor * m_mult + m_add
+        else:
+            return tensor * m_mult
+    else:
+        for d in modulation_dims:
+            tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+            if m_add is not None:
+                tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+        return tensor
+
+
 class DoubleStreamBlock(nn.Module):
     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
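The new `apply_mod` helper lets different token ranges pick different rows of a stacked modulation tensor. A minimal sketch of that behaviour with toy shapes (the numbers are illustrative, not taken from the repository):

```python
import torch

def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
    # copied from the diff above so the sketch is self-contained
    if modulation_dims is None:
        if m_add is not None:
            return tensor * m_mult + m_add
        else:
            return tensor * m_mult
    else:
        for d in modulation_dims:
            tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
            if m_add is not None:
                tensor[:, d[0]:d[1]] += m_add[:, d[2]]
        return tensor

tokens = torch.ones(1, 6, 4)                              # (batch, tokens, channels)
m_mult = torch.stack([torch.full((1, 4), 2.0),
                      torch.full((1, 4), 3.0)], dim=1)    # (batch, 2 modulation rows, channels)
dims = [(0, 3, 0), (3, None, 1)]                          # first 3 tokens -> row 0, the rest -> row 1
out = apply_mod(tokens, m_mult, None, dims)
print(out[0, 0, 0].item(), out[0, 5, 0].item())           # 2.0 and 3.0
```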
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
         self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
 
-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
+        x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
         x = self.linear(x)
         return x
@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q_shape = q.shape
     k_shape = k.shape
 
-    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
-    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
-    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+    if pe is not None:
+        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
+        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
+        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
 
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
 
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
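The rotary update is now computed in the dtype of the positional tensor instead of always upcasting to float32. A minimal sketch with made-up shapes (the sizes are assumptions, not taken from the diff), showing that the outputs are still cast back to the inputs' dtype:

```python
import torch

def apply_rope(xq, xk, freqs_cis):
    # same math as the diff above, kept here so the sketch runs standalone
    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

b, h, t, d = 1, 2, 4, 8
xq = torch.randn(b, h, t, d, dtype=torch.float16)
xk = torch.randn(b, h, t, d, dtype=torch.float16)
freqs = torch.randn(1, 1, t, d // 2, 2, 2, dtype=torch.bfloat16)  # per-position rotation terms
q_out, k_out = apply_rope(xq, xk, freqs)
assert q_out.dtype == torch.float16  # rotation ran in bfloat16, result cast back to float16
```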
@@ -115,8 +115,11 @@ class Flux(nn.Module):
         vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
         txt = self.txt_in(txt)
 
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
+        if img_ids is not None:
+            ids = torch.cat((txt_ids, img_ids), dim=1)
+            pe = self.pe_embedder(ids)
+        else:
+            pe = None
 
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        guiding_frame_index=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if guiding_frame_index is not None:
+            token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
+            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
+            modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+            modulation_dims_txt = [(0, None, 1)]
+        else:
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            modulation_dims = None
+            modulation_dims_txt = None
 
         if self.params.guidance_embed:
             if guidance is not None:
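How the two modulation rows line up with the token layout when a guiding frame is present can be seen with toy numbers (the shapes below are assumptions for illustration only): the guiding frame's image tokens read modulation row 0 (the `token_replace_vec` branch) and everything else reads row 1.

```python
# Illustrative sizes, not taken from the diff.
patch_size = (1, 2, 2)
initial_shape = (1, 16, 9, 64, 64)   # (b, c, t, h, w) before patching

frame_tokens = (initial_shape[-1] // patch_size[-1]) * (initial_shape[-2] // patch_size[-2])  # 1024

modulation_dims = [(0, frame_tokens, 0),      # tokens of the guiding frame -> modulation row 0
                   (frame_tokens, None, 1)]   # all remaining image tokens  -> modulation row 1
modulation_dims_txt = [(0, None, 1)]          # every text token            -> modulation row 1
```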
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                     return out
 
-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                     return out
 
-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
 
         img = img[:, : img_len]
 
-        img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
 
         shape = initial_shape[-3:]
         for i in range(len(shape)):
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
         return out
@@ -7,7 +7,7 @@ from einops import rearrange
 import math
 from typing import Dict, Optional, Tuple
 
-from .symmetric_patchifier import SymmetricPatchifier
+from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
 
 
 def get_timestep_embedding(
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
 
                  positional_embedding_theta=10000.0,
                  positional_embedding_max_pos=[20, 2048, 2048],
+                 causal_temporal_positioning=False,
+                 vae_scale_factors=(8, 32, 32),
                  dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.generator = None
+        self.vae_scale_factors = vae_scale_factors
         self.dtype = dtype
         self.out_channels = in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.causal_temporal_positioning = causal_temporal_positioning
 
         self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
 
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
 
-        indices_grid = self.patchifier.get_grid(
-            orig_num_frames=x.shape[2],
-            orig_height=x.shape[3],
-            orig_width=x.shape[4],
-            batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32),
-            device=x.device,
-        )
-
-        if guiding_latent is not None:
-            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
-            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
-            ts *= input_ts
-            ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
-            timestep = self.patchifier.patchify(ts)
-            input_x = x.clone()
-            x[:, :, 0] = guiding_latent[:, :, 0]
-            if guiding_latent_noise_scale > 0:
-                if self.generator is None:
-                    self.generator = torch.Generator(device=x.device).manual_seed(42)
-                elif self.generator.device != x.device:
-                    self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
-
-                noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
-                scale = guiding_latent_noise_scale * (input_ts ** 2)
-                guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
-
-                x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
-
         orig_shape = list(x.shape)
 
-        x = self.patchifier.patchify(x)
+        x, latent_coords = self.patchifier.patchify(x)
+        pixel_coords = latent_to_pixel_coords(
+            latent_coords=latent_coords,
+            scale_factors=self.vae_scale_factors,
+            causal_fix=self.causal_temporal_positioning,
+        )
+
+        if keyframe_idxs is not None:
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
+
+        fractional_coords = pixel_coords.to(torch.float32)
+        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
 
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
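The positional input to RoPE is now a per-token pixel coordinate, with the temporal axis expressed in seconds. A small sketch with made-up values (not taken from the diff) of the `fractional_coords` step:

```python
import torch

frame_rate = 25
# two tokens with (t, y, x) pixel coordinates (0, 0, 0) and (8, 32, 32)
pixel_coords = torch.tensor([[0, 0, 0], [8, 32, 32]]).T.unsqueeze(0).float()   # (b, 3, tokens)

fractional_coords = pixel_coords.clone()
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)          # frames -> seconds
print(fractional_coords[:, 0])   # tensor([[0.0000, 0.3200]])
```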
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
         if attention_mask is not None and not torch.is_floating_point(attention_mask):
             attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 
-        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
+        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
 
         batch_size = x.shape[0]
         timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
             out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
         )
 
-        if guiding_latent is not None:
-            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
-
-        # print("res", x)
         return x
@@ -6,16 +6,29 @@ from einops import rearrange
 from torch import Tensor
 
 
-def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(
-            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
-        )
-    elif dims_to_append == 0:
-        return x
-    return x[(...,) + (None,) * dims_to_append]
+def latent_to_pixel_coords(
+    latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
+) -> Tensor:
+    """
+    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
+    configuration.
+    Args:
+        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
+            containing the latent corner coordinates of each token.
+        scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
+        causal_fix (bool): Whether to take into account the different temporal scale
+            of the first frame. Default = False for backwards compatibility.
+    Returns:
+        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
+    """
+    pixel_coords = (
+        latent_coords
+        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
+    )
+    if causal_fix:
+        # Fix temporal scale for first frame to 1 due to causality
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
+    return pixel_coords
 
 
 class Patchifier(ABC):
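A small usage sketch of `latent_to_pixel_coords` with toy values (the coordinates below are invented for illustration): latent token coordinates are multiplied by the VAE scale factors, and the causal fix shifts the temporal axis so the first frame stays at position 0.

```python
import torch

latent_coords = torch.tensor([[[0, 1, 2],     # t indices of three tokens
                               [0, 0, 4],     # y indices
                               [0, 4, 0]]])   # x indices, shape (1, 3, num_tokens)
scale_factors = (8, 32, 32)

pixel = latent_coords * torch.tensor(scale_factors)[None, :, None]
pixel[:, 0] = (pixel[:, 0] + 1 - scale_factors[0]).clamp(min=0)   # the causal_fix=True branch
print(pixel[0, 0])   # temporal coords become 0, 1, 9 instead of 0, 8, 16
```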
@@ -44,29 +57,26 @@ class Patchifier(ABC):
     def patch_size(self):
         return self._patch_size
 
-    def get_grid(
-        self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
+    def get_latent_coords(
+        self, latent_num_frames, latent_height, latent_width, batch_size, device
     ):
-        f = orig_num_frames // self._patch_size[0]
-        h = orig_height // self._patch_size[1]
-        w = orig_width // self._patch_size[2]
-        grid_h = torch.arange(h, dtype=torch.float32, device=device)
-        grid_w = torch.arange(w, dtype=torch.float32, device=device)
-        grid_f = torch.arange(f, dtype=torch.float32, device=device)
-        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
-        grid = torch.stack(grid, dim=0)
-        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
-
-        if scale_grid is not None:
-            for i in range(3):
-                if isinstance(scale_grid[i], Tensor):
-                    scale = append_dims(scale_grid[i], grid.ndim - 1)
-                else:
-                    scale = scale_grid[i]
-                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
-
-        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
-        return grid
+        """
+        Return a tensor of shape [batch_size, 3, num_patches] containing the
+        top-left corner latent coordinates of each latent patch.
+        The tensor is repeated for each batch element.
+        """
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
+            torch.arange(0, latent_height, self._patch_size[1], device=device),
+            torch.arange(0, latent_width, self._patch_size[2], device=device),
+            indexing="ij",
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = rearrange(
+            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
+        )
+        return latent_coords
 
 
 class SymmetricPatchifier(Patchifier):
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
         self,
         latents: Tensor,
     ) -> Tuple[Tensor, Tensor]:
+        b, _, f, h, w = latents.shape
+        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
         latents = rearrange(
             latents,
             "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
             p2=self._patch_size[1],
             p3=self._patch_size[2],
         )
-        return latents
+        return latents, latent_coords
 
     def unpatchify(
         self,
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
         stride: Union[int, Tuple[int]] = 1,
         dilation: int = 1,
         groups: int = 1,
+        spatial_padding_mode: str = "zeros",
         **kwargs,
     ):
         super().__init__()
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
             stride=stride,
             dilation=dilation,
             padding=padding,
-            padding_mode="zeros",
+            padding_mode=spatial_padding_mode,
             groups=groups,
         )
 
@@ -1,13 +1,15 @@
+from __future__ import annotations
 import torch
 from torch import nn
 from functools import partial
 import math
 from einops import rearrange
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 from .conv_nd_factory import make_conv_nd, make_linear_nd
 from .pixel_norm import PixelNorm
 from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
 import comfy.ops
 
 ops = comfy.ops.disable_weight_init
 
 class Encoder(nn.Module):
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
         norm_layer (`str`, *optional*, defaults to `group_norm`):
             The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
         latent_log_var (`str`, *optional*, defaults to `per_channel`):
-            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
+            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
     """
 
     def __init__(
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
         dims: Union[int, Tuple[int, int]] = 3,
         in_channels: int = 3,
         out_channels: int = 3,
-        blocks=[("res_x", 1)],
+        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
         base_channels: int = 128,
         norm_num_groups: int = 32,
         patch_size: Union[int, Tuple[int]] = 1,
         norm_layer: str = "group_norm",  # group_norm, pixel_norm
         latent_log_var: str = "per_channel",
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.down_blocks = nn.ModuleList([])
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
                     resnet_eps=1e-6,
                     resnet_groups=norm_num_groups,
                     norm_layer=norm_layer,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "res_x_y":
                 output_channel = block_params.get("multiplier", 2) * output_channel
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
                     eps=1e-6,
                     groups=norm_num_groups,
                     norm_layer=norm_layer,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_time":
                 block = make_conv_nd(
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
                     kernel_size=3,
                     stride=(2, 1, 1),
                     causal=True,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_space":
                 block = make_conv_nd(
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
                     kernel_size=3,
                     stride=(1, 2, 2),
                     causal=True,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_all":
                 block = make_conv_nd(
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
                     kernel_size=3,
                     stride=(2, 2, 2),
                     causal=True,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_all_x_y":
                 output_channel = block_params.get("multiplier", 2) * output_channel
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
                     kernel_size=3,
                     stride=(2, 2, 2),
                     causal=True,
+                    spatial_padding_mode=spatial_padding_mode,
+                )
+            elif block_name == "compress_all_res":
+                output_channel = block_params.get("multiplier", 2) * output_channel
+                block = SpaceToDepthDownsample(
+                    dims=dims,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    stride=(2, 2, 2),
+                    spatial_padding_mode=spatial_padding_mode,
+                )
+            elif block_name == "compress_space_res":
+                output_channel = block_params.get("multiplier", 2) * output_channel
+                block = SpaceToDepthDownsample(
+                    dims=dims,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    stride=(1, 2, 2),
+                    spatial_padding_mode=spatial_padding_mode,
+                )
+            elif block_name == "compress_time_res":
+                output_channel = block_params.get("multiplier", 2) * output_channel
+                block = SpaceToDepthDownsample(
+                    dims=dims,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    stride=(2, 1, 1),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             else:
                 raise ValueError(f"unknown block: {block_name}")
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
             conv_out_channels *= 2
         elif latent_log_var == "uniform":
             conv_out_channels += 1
+        elif latent_log_var == "constant":
+            conv_out_channels += 1
         elif latent_log_var != "none":
             raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
         self.conv_out = make_conv_nd(
-            dims, output_channel, conv_out_channels, 3, padding=1, causal=True
+            dims,
+            output_channel,
+            conv_out_channels,
+            3,
+            padding=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.gradient_checkpointing = False
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
                 sample = torch.cat([sample, repeated_last_channel], dim=1)
             else:
                 raise ValueError(f"Invalid input shape: {sample.shape}")
+        elif self.latent_log_var == "constant":
+            sample = sample[:, :-1, ...]
+            approx_ln_0 = (
+                -30
+            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
+            sample = torch.cat(
+                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
+                dim=1,
+            )
 
         return sample
 
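What the `constant` mode means in practice can be shown with a toy tensor (the sizes are illustrative): the log-variance half of the output is filled with a large negative constant, so the latent distribution is effectively deterministic.

```python
import torch

means = torch.randn(1, 4, 2, 8, 8)                 # made-up latent means
approx_ln_0 = -30                                  # matches the minimal clamp in DiagonalGaussianDistribution
logvar = torch.ones_like(means) * approx_ln_0
std = torch.exp(0.5 * logvar)                      # ~3e-7, so sampling collapses onto the means
print(std.max().item())
```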
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
         dims,
         in_channels: int = 3,
         out_channels: int = 3,
-        blocks=[("res_x", 1)],
+        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
         base_channels: int = 128,
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
         norm_layer: str = "group_norm",
         causal: bool = True,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.up_blocks = nn.ModuleList([])
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=timestep_conditioning,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "attn_res_x":
                 block = UNetMidBlock3D(
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=timestep_conditioning,
                     attention_head_dim=block_params["attention_head_dim"],
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=False,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_time":
                 block = DepthToSpaceUpsample(
-                    dims=dims, in_channels=input_channel, stride=(2, 1, 1)
+                    dims=dims,
+                    in_channels=input_channel,
+                    stride=(2, 1, 1),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_space":
                 block = DepthToSpaceUpsample(
-                    dims=dims, in_channels=input_channel, stride=(1, 2, 2)
+                    dims=dims,
+                    in_channels=input_channel,
+                    stride=(1, 2, 2),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_all":
                 output_channel = output_channel // block_params.get("multiplier", 1)
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
                     stride=(2, 2, 2),
                     residual=block_params.get("residual", False),
                     out_channels_reduction_factor=block_params.get("multiplier", 1),
+                    spatial_padding_mode=spatial_padding_mode,
                 )
             else:
                 raise ValueError(f"unknown layer: {block_name}")
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
 
         self.conv_act = nn.SiLU()
         self.conv_out = make_conv_nd(
-            dims, output_channel, out_channels, 3, padding=1, causal=True
+            dims,
+            output_channel,
+            out_channels,
+            3,
+            padding=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         self.gradient_checkpointing = False
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
         resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
         resnet_groups (`int`, *optional*, defaults to 32):
             The number of groups to use in the group normalization layers of the resnet blocks.
+        norm_layer (`str`, *optional*, defaults to `group_norm`):
+            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+        inject_noise (`bool`, *optional*, defaults to `False`):
+            Whether to inject noise into the hidden states.
+        timestep_conditioning (`bool`, *optional*, defaults to `False`):
+            Whether to condition the hidden states on the timestep.
 
     Returns:
         `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         resnet_groups = (
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
                     norm_layer=norm_layer,
                     inject_noise=inject_noise,
                     timestep_conditioning=timestep_conditioning,
+                    spatial_padding_mode=spatial_padding_mode,
                 )
                 for _ in range(num_layers)
             ]
         )
 
     def forward(
-        self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
+        self,
+        hidden_states: torch.FloatTensor,
+        causal: bool = True,
+        timestep: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         timestep_embed = None
         if self.timestep_conditioning:
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
         return hidden_states
 
 
+class SpaceToDepthDownsample(nn.Module):
+    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
+        super().__init__()
+        self.stride = stride
+        self.group_size = in_channels * math.prod(stride) // out_channels
+        self.conv = make_conv_nd(
+            dims=dims,
+            in_channels=in_channels,
+            out_channels=out_channels // math.prod(stride),
+            kernel_size=3,
+            stride=1,
+            causal=True,
+            spatial_padding_mode=spatial_padding_mode,
+        )
+
+    def forward(self, x, causal: bool = True):
+        if self.stride[0] == 2:
+            x = torch.cat(
+                [x[:, :, :1, :, :], x], dim=2
+            )  # duplicate first frames for padding
+
+        # skip connection
+        x_in = rearrange(
+            x,
+            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+            p1=self.stride[0],
+            p2=self.stride[1],
+            p3=self.stride[2],
+        )
+        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
+        x_in = x_in.mean(dim=2)
+
+        # conv
+        x = self.conv(x, causal=causal)
+        x = rearrange(
+            x,
+            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+            p1=self.stride[0],
+            p2=self.stride[1],
+            p3=self.stride[2],
+        )
+
+        x = x + x_in
+
+        return x
+
+
 class DepthToSpaceUpsample(nn.Module):
     def __init__(
-        self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
+        self,
+        dims,
+        in_channels,
+        stride,
+        residual=False,
+        out_channels_reduction_factor=1,
+        spatial_padding_mode="zeros",
     ):
         super().__init__()
         self.stride = stride
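The channel bookkeeping in `SpaceToDepthDownsample` can be checked with toy sizes (the numbers below are assumptions, not from the diff): the conv emits `out_channels // prod(stride)` channels so that folding the strided neighbourhood back into channels yields `out_channels`, while the skip path averages over `group_size` folded input channels to match.

```python
import math

in_channels, out_channels, stride = 128, 512, (2, 2, 2)

group_size = in_channels * math.prod(stride) // out_channels   # 128 * 8 / 512 = 2
conv_channels = out_channels // math.prod(stride)              # 64 per position before the rearrange

# after the "(d p1) (h p2) (w p3) -> (c p1 p2 p3)" rearrange:
assert conv_channels * math.prod(stride) == out_channels       # conv branch: 64 * 8 = 512
assert in_channels * math.prod(stride) // group_size == out_channels  # skip branch: 1024 / 2 = 512
```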
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
             kernel_size=3,
             stride=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
         self.residual = residual
         self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -558,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module):
 class LayerNorm(nn.Module):
     def __init__(self, dim, eps, elementwise_affine=True) -> None:
         super().__init__()
-        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
 
     def forward(self, x):
         x = rearrange(x, "b c d h w -> b d h w c")
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
         timestep_conditioning: bool = False,
+        spatial_padding_mode: str = "zeros",
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         if inject_noise:
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
             stride=1,
             padding=1,
             causal=True,
+            spatial_padding_mode=spatial_padding_mode,
         )
 
         if inject_noise:
@@ -801,9 +941,44 @@ class processor(nn.Module):
         return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
-    def __init__(self, version=0):
+    def __init__(self, version=0, config=None):
         super().__init__()
 
+        if config is None:
+            config = self.guess_config(version)
+
+        self.timestep_conditioning = config.get("timestep_conditioning", False)
+        double_z = config.get("double_z", True)
+        latent_log_var = config.get(
+            "latent_log_var", "per_channel" if double_z else "none"
+        )
+
+        self.encoder = Encoder(
+            dims=config["dims"],
+            in_channels=config.get("in_channels", 3),
+            out_channels=config["latent_channels"],
+            blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
+            patch_size=config.get("patch_size", 1),
+            latent_log_var=latent_log_var,
+            norm_layer=config.get("norm_layer", "group_norm"),
+            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+        )
+
+        self.decoder = Decoder(
+            dims=config["dims"],
+            in_channels=config["latent_channels"],
+            out_channels=config.get("out_channels", 3),
+            blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
+            patch_size=config.get("patch_size", 1),
+            norm_layer=config.get("norm_layer", "group_norm"),
+            causal=config.get("causal_decoder", False),
+            timestep_conditioning=self.timestep_conditioning,
+            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+        )
+
+        self.per_channel_statistics = processor()
+
+    def guess_config(self, version):
         if version == 0:
             config = {
                 "_class_name": "CausalVideoAutoencoder",
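A hypothetical usage sketch of the new constructor (the config below is invented and trimmed, it is only meant to show which keys are read; real checkpoints supply the full block lists shown in `guess_config`):

```python
# When config is omitted, guess_config(version) falls back to the hard-coded
# version 0/1 settings plus the new version 2 layout.
vae = VideoVAE(config={
    "dims": 3,
    "latent_channels": 128,
    "blocks": [["res_x", {"num_layers": 4}]],        # used for both encoder and decoder
    "spatial_padding_mode": "reflect",                # forwarded to every conv
    "timestep_conditioning": False,
})
```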
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
|
|||||||
"use_quant_conv": False,
|
"use_quant_conv": False,
|
||||||
"causal_decoder": False,
|
"causal_decoder": False,
|
||||||
}
|
}
|
||||||
else:
|
elif version == 1:
|
||||||
config = {
|
config = {
|
||||||
"_class_name": "CausalVideoAutoencoder",
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
"dims": 3,
|
"dims": 3,
|
||||||
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
|
|||||||
"causal_decoder": False,
|
"causal_decoder": False,
|
||||||
"timestep_conditioning": True,
|
"timestep_conditioning": True,
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
double_z = config.get("double_z", True)
|
config = {
|
||||||
latent_log_var = config.get(
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
"latent_log_var", "per_channel" if double_z else "none"
|
"dims": 3,
|
||||||
)
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
self.encoder = Encoder(
|
"latent_channels": 128,
|
||||||
dims=config["dims"],
|
"encoder_blocks": [
|
||||||
in_channels=config.get("in_channels", 3),
|
["res_x", {"num_layers": 4}],
|
||||||
out_channels=config["latent_channels"],
|
["compress_space_res", {"multiplier": 2}],
|
||||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
["res_x", {"num_layers": 6}],
|
||||||
patch_size=config.get("patch_size", 1),
|
["compress_time_res", {"multiplier": 2}],
|
||||||
latent_log_var=latent_log_var,
|
["res_x", {"num_layers": 6}],
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
["compress_all_res", {"multiplier": 2}],
|
||||||
)
|
["res_x", {"num_layers": 2}],
|
||||||
|
["compress_all_res", {"multiplier": 2}],
|
||||||
self.decoder = Decoder(
|
["res_x", {"num_layers": 2}]
|
||||||
dims=config["dims"],
|
],
|
||||||
in_channels=config["latent_channels"],
|
"decoder_blocks": [
|
||||||
out_channels=config.get("out_channels", 3),
|
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
patch_size=config.get("patch_size", 1),
|
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
causal=config.get("causal_decoder", False),
|
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||||
timestep_conditioning=config.get("timestep_conditioning", False),
|
["compress_all", {"residual": True, "multiplier": 2}],
|
||||||
)
|
["res_x", {"num_layers": 5, "inject_noise": False}]
|
||||||
|
],
|
||||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
"scaling_factor": 1.0,
|
||||||
self.per_channel_statistics = processor()
|
"norm_layer": "pixel_norm",
|
||||||
|
"patch_size": 4,
|
||||||
|
"latent_log_var": "uniform",
|
||||||
|
"use_quant_conv": False,
|
||||||
|
"causal_decoder": False,
|
||||||
|
"timestep_conditioning": True
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
|
frames_count = x.shape[2]
|
||||||
|
if ((frames_count - 1) % 8) != 0:
|
||||||
|
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
|
|||||||
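The encode() change above rejects clips whose frame count is not of the form 1 + 8 * k. A minimal sketch of trimming an input batch to the nearest valid length; the helper name is hypothetical and not part of this diff:

import torch

def trim_to_valid_frame_count(video: torch.Tensor) -> torch.Tensor:
    # video is (B, C, T, H, W); keep the largest T' <= T with (T' - 1) % 8 == 0
    t = video.shape[2]
    valid_t = ((t - 1) // 8) * 8 + 1
    return video[:, :, :valid_t]

clip = torch.randn(1, 3, 20, 64, 64)
print(trim_to_valid_frame_count(clip).shape[2])  # 17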
@@ -17,7 +17,11 @@ def make_conv_nd(
    groups=1,
    bias=True,
    causal=False,
+    spatial_padding_mode="zeros",
+    temporal_padding_mode="zeros",
):
+    if not (spatial_padding_mode == temporal_padding_mode or causal):
+        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return ops.Conv2d(
            in_channels=in_channels,
@@ -28,6 +32,7 @@ def make_conv_nd(
            dilation=dilation,
            groups=groups,
            bias=bias,
+            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
@@ -40,6 +45,7 @@ def make_conv_nd(
                dilation=dilation,
                groups=groups,
                bias=bias,
+                spatial_padding_mode=spatial_padding_mode,
            )
        return ops.Conv3d(
            in_channels=in_channels,
@@ -50,6 +56,7 @@ def make_conv_nd(
            dilation=dilation,
            groups=groups,
            bias=bias,
+            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        return DualConv3d(
@@ -59,6 +66,7 @@ def make_conv_nd(
            stride=stride,
            padding=padding,
            bias=bias,
+            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
+        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
+        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
                self.padding1,
                self.dilation1,
                self.groups,
+                padding_mode=self.padding_mode,
            )

        if skip_time_conv:
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
                self.padding2,
                self.dilation2,
                self.groups,
+                padding_mode=self.padding_mode,
            )

        return x
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
-        x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
+        x = F.conv2d(
+            x,
+            weight1,
+            self.bias1,
+            stride1,
+            padding1,
+            dilation1,
+            self.groups,
+            padding_mode=self.padding_mode,
+        )

        _, _, h, w = x.shape

@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
-        x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
+        x = F.conv1d(
+            x,
+            weight2,
+            self.bias2,
+            stride2,
+            padding2,
+            dilation2,
+            self.groups,
+            padding_mode=self.padding_mode,
+        )
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

        return x
@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
        logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
        exit(-1)

+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
    return out

+
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
+
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
+
optimized_attention = attention_basic

if model_management.sage_attention_enabled():
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
    logging.info("Using pytorch attention")
    optimized_attention = attention_pytorch
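attention_flash above falls back to PyTorch SDPA whenever the flash_attn call raises (including when a mask is present). A rough sketch of the same reshape convention using only stock PyTorch, to show the expected input and output shapes; this is an illustration, not the repo function:

import torch
import torch.nn.functional as F

def sdpa_like_attention_flash(q, k, v, heads):
    # q, k, v: (batch, tokens, heads * dim_head)
    b, _, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
print(sdpa_like_attention_flash(q, k, v, heads=8).shape)  # torch.Size([2, 77, 512])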
@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
        context,
        clip_fea=None,
        freqs=None,
+        transformer_options={},
    ):
        r"""
        Forward pass through the diffusion model
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

-        # arguments
-        kwargs = dict(
-            e=e0,
-            freqs=freqs,
-            context=context)
-
-        for block in self.blocks:
-            x = block(x, **kwargs)
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context)

        # head
        x = self.head(x, e)
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
        x = self.unpatchify(x, grid_sizes)
        return x

-    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        patch_size = self.patch_size
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]

    def unpatchify(self, x, grid_sizes):
        r"""
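The new forward_orig loop looks up transformer_options["patches_replace"]["dit"] by ("double_block", i), so a callable can wrap or replace an individual block. A hedged sketch of the calling convention this implies; my_patch is a hypothetical user callback, not part of the diff:

# The patch receives the block inputs as a dict and an options dict holding the
# untouched block under "original_block"; it must return a dict with "img".
def my_patch(args, options):
    out = options["original_block"](args)
    out["img"] = out["img"] * 1.0  # post-process the block output here
    return out

transformer_options = {"patches_replace": {"dit": {("double_block", 0): my_patch}}}
# Passing transformer_options through WanModel.forward(...) would route block 0
# through my_patch and leave every other block on the regular path.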
@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):

        if not unet_config.get("disable_unet_model_creation", False):
            if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+                fp8 = model_config.optimizations.get("fp8", False)
                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
            else:
                operations = model_config.custom_operations
@@ -161,9 +161,13 @@ class BaseModel(torch.nn.Module):
                extra = extra.to(dtype)
            extra_conds[o] = extra

+        t = self.process_timestep(t, x=x, **extra_conds)
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

+    def process_timestep(self, timestep, **kwargs):
+        return timestep
+
    def get_dtype(self):
        return self.diffusion_model.dtype

@@ -185,6 +189,11 @@ class BaseModel(torch.nn.Module):

            if concat_latent_image.shape[1:] != noise.shape[1:]:
                concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if noise.ndim == 5:
+                if concat_latent_image.shape[-3] < noise.shape[-3]:
+                    concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
+                else:
+                    concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]

            concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

@@ -213,6 +222,11 @@ class BaseModel(torch.nn.Module):
                        cond_concat.append(self.blank_inpaint_image_like(noise))
                elif ck == "mask_inverted":
                    cond_concat.append(torch.zeros_like(noise)[:, :1])
+                if ck == "concat_image":
+                    if concat_latent_image is not None:
+                        cond_concat.append(concat_latent_image.to(device))
+                    else:
+                        cond_concat.append(torch.zeros_like(noise))
            data = torch.cat(cond_concat, dim=1)
            return data
        return None
@@ -845,17 +859,26 @@ class LTXV(BaseModel):
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

-        guiding_latent = kwargs.get("guiding_latent", None)
-        if guiding_latent is not None:
-            out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
-
-        guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
-        if guiding_latent_noise_scale is not None:
-            out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
-
        out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))

+        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+
+        keyframe_idxs = kwargs.get("keyframe_idxs", None)
+        if keyframe_idxs is not None:
+            out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
+
        return out

+    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+        if denoise_mask is None:
+            return timestep
+        return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return latent_image
+
class HunyuanVideo(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -872,20 +895,35 @@ class HunyuanVideo(BaseModel):
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

-        image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
-
-        if image is not None:
-            padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
-            latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
-            image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
-            out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
-
        guidance = kwargs.get("guidance", 6.0)
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))

+        guiding_frame_index = kwargs.get("guiding_frame_index", None)
+        if guiding_frame_index is not None:
+            out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
+
        return out

+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return latent_image
+
+class HunyuanVideoI2V(HunyuanVideo):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.concat_keys = ("concat_image", "mask_inverted")
+
+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
+class HunyuanVideoSkyreelsI2V(HunyuanVideo):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.concat_keys = ("concat_image",)
+
+    def scale_latent_inpaint(self, latent_image, **kwargs):
+        return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
class CosmosVideo(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
@@ -935,11 +973,11 @@ class WAN21(BaseModel):
        self.image_to_video = image_to_video

    def concat_cond(self, **kwargs):
-        if not self.image_to_video:
+        noise = kwargs.get("noise", None)
+        if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
            return None

        image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
@@ -949,6 +987,9 @@ class WAN21(BaseModel):
        image = self.process_latent_in(image)
        image = utils.resize_to_batch_size(image, noise.shape[0])

+        if not self.image_to_video:
+            return image
+
        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.zeros_like(noise)[:, :4]
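BaseModel now routes the timestep through process_timestep() before calling the diffusion model, and LTXV overrides it to broadcast the per-batch sigma over a denoise mask before patchifying. A toy sketch of just the broadcasting step, with made-up shapes and no patchifier:

import torch

timestep = torch.tensor([0.7, 0.3])          # one value per batch item
denoise_mask = torch.ones(2, 1, 8, 16, 16)   # 1.0 where denoising is allowed
per_token_t = denoise_mask * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))
print(per_token_t.shape)                     # torch.Size([2, 1, 8, 16, 16])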
@@ -1,3 +1,4 @@
+import json
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
            return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
    return None

-def detect_unet_config(state_dict, key_prefix):
+def detect_unet_config(state_dict, key_prefix, metadata=None):
    state_dict_keys = list(state_dict.keys())

    if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix):
    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
        dit_config = {}
        dit_config["image_model"] = "ltxv"
+        if metadata is not None and "config" in metadata:
+            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
        return dit_config

    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@@ -454,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
    logging.error("no match {}".format(unet_config))
    return None

-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
    if unet_config is None:
        return None
    model_config = model_config_from_unet_config(unet_config, state_dict)
@@ -468,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
        model_config.scaled_fp8 = scaled_fp8_weight.dtype
        if model_config.scaled_fp8 == torch.float32:
            model_config.scaled_fp8 = torch.float8_e4m3fn
+        if scaled_fp8_weight.nelement() == 2:
+            model_config.optimizations["fp8"] = False
+        else:
+            model_config.optimizations["fp8"] = True

    return model_config
@@ -186,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
    else:
        return mem_total

+def mac_version():
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

try:
    logging.info("pytorch version: {}".format(torch_version))
+    mac_ver = mac_version()
+    if mac_ver is not None:
+        logging.info("Mac Version {}".format(mac_ver))
except:
    pass

@@ -581,7 +590,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
            loaded_memory = loaded_model.model_loaded_memory()
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

-            lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

            if vram_set_state == VRAMState.NO_VRAM:
@@ -921,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
    return args.use_sage_attention

+def flash_attention_enabled():
+    return args.use_flash_attention
+
def xformers_enabled():
    global directml_enabled
    global cpu_state
@@ -969,12 +981,6 @@ def pytorch_attention_flash_attention():
            return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
    return False

-def mac_version():
-    try:
-        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
-    except:
-        return None
-
def force_upcast_attention_dtype():
    upcast = args.force_upcast_attention
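mac_version() was moved up next to the startup logging so the macOS release can be reported; it returns a tuple of ints on a Mac and None elsewhere (platform.mac_ver() yields an empty string off macOS, so the int() call fails and is swallowed). A small sketch of gating a workaround on it, with an arbitrary version threshold:

import platform

def mac_version():
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except:
        return None

ver = mac_version()
if ver is not None and ver < (14, 0):
    print("older macOS detected:", ver)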
@@ -747,6 +747,7 @@ class ModelPatcher:

    def partially_unload(self, device_to, memory_to_free=0):
        with self.use_ejected():
+            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
@@ -770,6 +771,10 @@ class ModelPatcher:
                        move_weight = False
                        break

+                if not hooks_unpatched:
+                    self.unpatch_hooks()
+                    hooks_unpatched = True
+
                if bk.inplace_update:
                    comfy.utils.copy_to_param(self.model, key, bk.weight)
                else:
@@ -1089,7 +1094,6 @@ class ModelPatcher:

    def patch_hooks(self, hooks: comfy.hooks.HookGroup):
        with self.use_ejected():
-            self.unpatch_hooks()
            if hooks is not None:
                model_sd_keys = list(self.model_state_dict().keys())
                memory_counter = None
@@ -1100,12 +1104,16 @@ class ModelPatcher:
                # if have cached weights for hooks, use it
                cached_weights = self.cached_hook_patches.get(hooks, None)
                if cached_weights is not None:
+                    model_sd_keys_set = set(model_sd_keys)
                    for key in cached_weights:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                        model_sd_keys_set.remove(key)
+                    self.unpatch_hooks(model_sd_keys_set)
                else:
+                    self.unpatch_hooks()
                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                    original_weights = None
                    if len(relevant_patches) > 0:
@@ -1116,6 +1124,8 @@ class ModelPatcher:
                            continue
                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                         memory_counter=memory_counter)
+            else:
+                self.unpatch_hooks()
            self.current_hooks = hooks

    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1172,17 +1182,23 @@ class ModelPatcher:
        del out_weight
        del weight

-    def unpatch_hooks(self) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
            keys = list(self.hook_backup.keys())
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+            if whitelist_keys_set:
+                for k in keys:
+                    if k in whitelist_keys_set:
+                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                        self.hook_backup.pop(k)
+            else:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

            self.hook_backup.clear()
            self.current_hooks = None

    def clean_hooks(self):
        self.unpatch_hooks()
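unpatch_hooks() now accepts an optional whitelist so patch_hooks() can restore only the keys that cached hook weights are about to overwrite, instead of reverting everything first. A toy sketch of the whitelist-restore idea using plain dicts rather than model parameters:

weights = {"a": 10, "b": 20, "c": 30}   # currently patched values
backup = {"a": 1, "b": 2, "c": 3}       # originals saved before patching

def unpatch(whitelist=None):
    for k in list(backup.keys()):
        if whitelist is None or k in whitelist:
            weights[k] = backup.pop(k)

unpatch(whitelist={"b"})
print(weights)  # {'a': 10, 'b': 2, 'c': 30}: only "b" was restored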
@@ -17,6 +17,7 @@
"""

import torch
+import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
            return torch.nn.functional.linear(input, weight, bias)

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

    if (
        fp8_compute and
@@ -19,6 +19,12 @@ import comfy.hooks
import scipy.stats
import numpy

+
+def add_area_dims(area, num_dims):
+    while (len(area) // 2) < num_dims:
+        area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
+    return area
+
def get_area_and_mult(conds, x_in, timestep_in):
    dims = tuple(x_in.shape[2:])
    area = None
@@ -34,8 +40,9 @@ def get_area_and_mult(conds, x_in, timestep_in):
            return None
    if 'area' in conds:
        area = list(conds['area'])
-        while (len(area) // 2) < len(dims):
-            area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
+        area = add_area_dims(area, len(dims))
+        if (len(area) // 2) > len(dims):
+            area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]

    if 'strength' in conds:
        strength = conds['strength']
@@ -53,7 +60,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
-        assert(mask.shape[1:] == x_in.shape[2:])
+        assert (mask.shape[1:] == x_in.shape[2:])

        mask = mask[:input_x.shape[0]]
        if area is not None:
@@ -67,16 +74,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
        mult = mask * strength

    if 'mask' not in conds and area is not None:
-        rr = 8
+        fuzz = 8
        for i in range(len(dims)):
+            rr = min(fuzz, mult.shape[2 + i] // 4)
            if area[len(dims) + i] != 0:
                for t in range(rr):
                    m = mult.narrow(i + 2, t, 1)
-                    m *= ((1.0/rr) * (t + 1))
+                    m *= ((1.0 / rr) * (t + 1))
            if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
                for t in range(rr):
                    m = mult.narrow(i + 2, area[i] - 1 - t, 1)
-                    m *= ((1.0/rr) * (t + 1))
+                    m *= ((1.0 / rr) * (t + 1))

    conditioning = {}
    model_conds = conds["model_conds"]
@@ -551,25 +559,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
    logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
    return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)

-def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
+def create_cond_with_same_area_if_none(conds, c):
    if 'area' not in c:
        return

+    def area_inside(a, area_cmp):
+        a = add_area_dims(a, len(area_cmp) // 2)
+        area_cmp = add_area_dims(area_cmp, len(a) // 2)
+
+        a_l = len(a) // 2
+        area_cmp_l = len(area_cmp) // 2
+        for i in range(min(a_l, area_cmp_l)):
+            if a[a_l + i] < area_cmp[area_cmp_l + i]:
+                return False
+        for i in range(min(a_l, area_cmp_l)):
+            if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
+                return False
+        return True
+
    c_area = c['area']
    smallest = None
    for x in conds:
        if 'area' in x:
            a = x['area']
-            if c_area[2] >= a[2] and c_area[3] >= a[3]:
-                if a[0] + a[2] >= c_area[0] + c_area[2]:
-                    if a[1] + a[3] >= c_area[1] + c_area[3]:
-                        if smallest is None:
-                            smallest = x
-                        elif 'area' not in smallest:
-                            smallest = x
-                        else:
-                            if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
-                                smallest = x
+            if area_inside(c_area, a):
+                if smallest is None:
+                    smallest = x
+                elif 'area' not in smallest:
+                    smallest = x
+                else:
+                    if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
+                        smallest = x
        else:
            if smallest is None:
                smallest = x
@@ -690,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-                  "gradient_estimation"]
+                  "gradient_estimation", "er_sde"]

class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
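add_area_dims() pads a conditioning 'area' out to the dimensionality of the latent by prepending a sentinel size (2147483648) and a zero offset per missing axis. A quick sketch of what that does to a 2D image area when the latent gains a temporal axis:

def add_area_dims(area, num_dims):
    while (len(area) // 2) < num_dims:
        area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
    return area

area_2d = [32, 48, 8, 16]            # (h, w, y, x) for an image latent
print(add_area_dims(area_2d, 3))     # [2147483648, 32, 48, 0, 8, 16] for a video latent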
comfy/sd.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+import json
import torch
from enum import Enum
import logging
@@ -134,8 +135,8 @@ class CLIP:
    def clip_layer(self, layer_idx):
        self.layer_idx = layer_idx

-    def tokenize(self, text, return_word_ids=False):
-        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
+    def tokenize(self, text, return_word_ids=False, **kwargs):
+        return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)

    def add_hooks_to_dict(self, pooled_dict: dict[str]):
        if self.apply_hooks_to_conds:
@@ -249,7 +250,7 @@ class CLIP:
        return self.patcher.get_key_patches()

class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)

@@ -357,7 +358,12 @@ class VAE:
                version = 0
            elif tensor_conv1.shape[0] == 1024:
                version = 1
-            self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
+                if "encoder.down_blocks.1.conv.conv.bias" in sd:
+                    version = 2
+            vae_config = None
+            if metadata is not None and "config" in metadata:
+                vae_config = json.loads(metadata["config"]).get("vae", None)
+            self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
            self.latent_channels = 128
            self.latent_dim = 3
            self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@@ -434,6 +440,10 @@ class VAE:
        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

+    def throw_exception_if_invalid(self):
+        if self.first_stage_model is None:
+            raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
+
    def vae_encode_crop_pixels(self, pixels):
        downscale_ratio = self.spacial_compression_encode()

@@ -489,6 +499,7 @@ class VAE:
        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

    def decode(self, samples_in):
+        self.throw_exception_if_invalid()
        pixel_samples = None
        try:
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -519,6 +530,7 @@ class VAE:
        return pixel_samples

    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
@@ -547,6 +559,7 @@ class VAE:
        return output.movedim(1, -1)

    def encode(self, pixel_samples):
+        self.throw_exception_if_invalid()
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -579,6 +592,7 @@ class VAE:
        return samples

    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+        self.throw_exception_if_invalid()
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        dims = self.latent_dim
        pixel_samples = pixel_samples.movedim(-1, 1)
@@ -873,13 +887,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    return (model, clip, vae)

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
-    sd = comfy.utils.load_torch_file(ckpt_path)
-    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
+    sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
    if out is None:
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
    return out

-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
    clip = None
    clipvision = None
    vae = None
@@ -891,9 +905,14 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
    load_device = model_management.get_torch_device()

-    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
+    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
    if model_config is None:
-        return None
+        logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
+        diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
+        if diffusion_model is None:
+            return None
+        return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'

    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if model_config.scaled_fp8 is not None:
@@ -920,7 +939,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
        vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd)
+        vae = VAE(sd=vae_sd, metadata=metadata)

    if output_clip:
        clip_target = model_config.clip_target(state_dict=sd)
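load_checkpoint_guess_config() now asks load_torch_file() for the safetensors metadata and threads it into model detection and the VAE, where a JSON "config" entry can carry per-component settings. A sketch of reading that kind of metadata directly with the safetensors library; the file path is a placeholder and the "config"/"vae" keys are only present on checkpoints that ship them:

import json
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:  # placeholder path
    metadata = f.metadata() or {}

vae_config = None
if "config" in metadata:
    vae_config = json.loads(metadata["config"]).get("vae", None)
print(vae_config)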
@@ -158,71 +158,93 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_idx = self.options_default[1]
         self.return_projected_pooled = self.options_default[2]

-    def set_up_textual_embeddings(self, tokens, current_embeds):
+    def process_tokens(self, tokens, device):
-        out_tokens = []
+        end_token = self.special_tokens.get("end", None)
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        if end_token is None:
-        embedding_weights = []
+            cmp_token = self.special_tokens.get("pad", -1)
+        else:
+            cmp_token = end_token

+        embeds_out = []
+        attention_masks = []
+        num_tokens = []

         for x in tokens:
+            attention_mask = []
             tokens_temp = []
+            other_embeds = []
+            eos = False
+            index = 0
             for y in x:
                 if isinstance(y, numbers.Integral):
-                    tokens_temp += [int(y)]
+                    if eos:
-                else:
+                        attention_mask.append(0)
-                    if y.shape[0] == current_embeds.weight.shape[1]:
-                        embedding_weights += [y]
-                        tokens_temp += [next_new_token]
-                        next_new_token += 1
                     else:
-                        logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
+                        attention_mask.append(1)
-            while len(tokens_temp) < len(x):
+                    token = int(y)
-                tokens_temp += [self.special_tokens["pad"]]
+                    tokens_temp += [token]
-            out_tokens += [tokens_temp]
+                    if not eos and token == cmp_token:
+                        if end_token is None:
+                            attention_mask[-1] = 0
+                        eos = True
+                else:
+                    other_embeds.append((index, y))
+                index += 1

-        n = token_dict_size
+            tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
-        if len(embedding_weights) > 0:
+            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
-            new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+            index = 0
-            new_embedding.weight[:token_dict_size] = current_embeds.weight
+            pad_extra = 0
-            for x in embedding_weights:
+            for o in other_embeds:
-                new_embedding.weight[n] = x
+                emb = o[1]
-                n += 1
+                if torch.is_tensor(emb):
-            self.transformer.set_input_embeddings(new_embedding)
+                    emb = {"type": "embedding", "data": emb}

-        processed_tokens = []
+                emb_type = emb.get("type", None)
-        for x in out_tokens:
+                if emb_type == "embedding":
-            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+                    emb = emb.get("data", None)
+                else:
+                    if hasattr(self.transformer, "preprocess_embed"):
+                        emb = self.transformer.preprocess_embed(emb, device=device)
+                    else:
+                        emb = None

-        return processed_tokens
+                if emb is None:
+                    index += -1
+                    continue

+                ind = index + o[0]
+                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
+                emb_shape = emb.shape[1]
+                if emb.shape[-1] == tokens_embed.shape[-1]:
+                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
+                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
+                    index += emb_shape - 1
+                else:
+                    index += -1
+                    pad_extra += emb_shape
+                    logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))

+            if pad_extra > 0:
+                padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
+                tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
+                attention_mask = attention_mask + [0] * pad_extra

+            embeds_out.append(tokens_embed)
+            attention_masks.append(attention_mask)
+            num_tokens.append(sum(attention_mask))

+        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens

     def forward(self, tokens):
-        backup_embeds = self.transformer.get_input_embeddings()
+        device = self.transformer.get_input_embeddings().weight.device
-        device = backup_embeds.weight.device
+        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
-        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(device)
-
-        attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
-            attention_mask = torch.zeros_like(tokens)
-            end_token = self.special_tokens.get("end", None)
-            if end_token is None:
-                cmp_token = self.special_tokens.get("pad", -1)
-            else:
-                cmp_token = end_token
-
-            for x in range(attention_mask.shape[0]):
-                for y in range(attention_mask.shape[1]):
-                    attention_mask[x, y] = 1
-                    if tokens[x, y] == cmp_token:
-                        if end_token is None:
-                            attention_mask[x, y] = 0
-                        break

         attention_mask_model = None
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
-        self.transformer.set_input_embeddings(backup_embeds)

         if self.layer == "last":
             z = outputs[0].float()

@@ -482,7 +504,7 @@ class SDTokenizer:
         return (embed, leftover)


-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         '''
         Takes a prompt and converts it to a list of (token, weight, word id) elements.
         Tokens can both be integer tokens and pre computed CLIP tensors.

@@ -596,7 +618,7 @@ class SD1Tokenizer:
         tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
         setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
         return out

@@ -26,7 +26,7 @@ class SDXLTokenizer:
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

@@ -762,7 +762,7 @@ class LTXV(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.LTXV

-    memory_usage_factor = 2.7
+    memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -826,6 +826,26 @@ class HunyuanVideo(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))

+class HunyuanVideoI2V(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "in_channels": 33,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanVideoI2V(self, device=device)
+        return out
+
+class HunyuanVideoSkyreelsI2V(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "in_channels": 32,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
+        return out
+
 class CosmosT2V(supported_models_base.BASE):
     unet_config = {
         "image_model": "cosmos",

@@ -911,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):

     memory_usage_factor = 1.0

-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

@@ -939,6 +959,6 @@ class WAN21_I2V(WAN21_T2V):
         out = model_base.WAN21(self, image_to_video=True, device=device)
         return out

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]

 models += [SVD_img2vid]

@@ -93,8 +93,11 @@ class BertEmbeddings(torch.nn.Module):

         self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

-    def forward(self, input_tokens, token_type_ids=None, dtype=None):
+    def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
-        x = self.word_embeddings(input_tokens, out_dtype=dtype)
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.word_embeddings(input_tokens, out_dtype=dtype)
         x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
         if token_type_ids is not None:
             x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)

@@ -113,8 +116,8 @@ class BertModel_(torch.nn.Module):
         self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
         self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
-        x = self.embeddings(input_tokens, dtype=dtype)
+        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

@@ -18,7 +18,7 @@ class FluxTokenizer:
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
         out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)

@@ -4,6 +4,7 @@ import comfy.text_encoders.llama
 from transformers import LlamaTokenizerFast
 import torch
 import os
+import numbers


 def llama_detect(state_dict, prefix=""):

@@ -22,7 +23,7 @@ def llama_detect(state_dict, prefix=""):
 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)

 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):

@@ -38,15 +39,26 @@ class HunyuanVideoTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
+        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
         self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

-        llama_text = "{}{}".format(self.llama_template, text)
+        if llama_template is None:
-        out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+        embed_count = 0
+        for r in llama_text_tokens:
+            for i in range(len(r)):
+                if r[i][0] == 128257:
+                    if image_embeds is not None and embed_count < image_embeds.shape[0]:
+                        r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
+                        embed_count += 1
+        out["llama"] = llama_text_tokens
         return out

     def untokenize(self, token_weight_pair):

@@ -80,20 +92,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
         llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)

         template_end = 0
-        for i, v in enumerate(token_weight_pairs_llama[0]):
+        extra_template_end = 0
-            if v[0] == 128007: # <|end_header_id|>
+        extra_sizes = 0
-                template_end = i
+        user_end = 9999999999999
+        images = []
+
+        tok_pairs = token_weight_pairs_llama[0]
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 128006:
+                        if tok_pairs[i + 1][0] == 882:
+                            if tok_pairs[i + 2][0] == 128007:
+                                template_end = i + 2
+                                user_end = -1
+                    if elem == 128009 and user_end == -1:
+                        user_end = i + 1
+                else:
+                    if elem.get("original_type") == "image":
+                        elem_size = elem.get("data").shape[0]
+                        if template_end > 0:
+                            if user_end == -1:
+                                extra_template_end += elem_size - 1
+                        else:
+                            image_start = i + extra_sizes
+                            image_end = i + elem_size + extra_sizes
+                            images.append((image_start, image_end, elem.get("image_interleave", 1)))
+                            extra_sizes += elem_size - 1

         if llama_out.shape[1] > (template_end + 2):
-            if token_weight_pairs_llama[0][template_end + 1][0] == 271:
+            if tok_pairs[template_end + 1][0] == 271:
                 template_end += 2
-        llama_out = llama_out[:, template_end:]
+        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
-        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
+        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
         if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
             llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements

+        if len(images) > 0:
+            out = []
+            for i in images:
+                out.append(llama_out[:, i[0]: i[1]: i[2]])
+            llama_output = torch.cat(out + [llama_output], dim=1)
+
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return llama_out, l_pooled, llama_extra_out
+        return llama_output, l_pooled, llama_extra_out

     def load_sd(self, sd):
         if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:

@@ -37,7 +37,7 @@ class HyditTokenizer:
         self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
         self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
         out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)

@@ -241,8 +241,11 @@ class Llama2_(nn.Module):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

-    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
-        x = self.embed_tokens(x, out_dtype=dtype)
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.embed_tokens(x, out_dtype=dtype)

         if self.normalize_in:
             x *= self.config.hidden_size ** 0.5

@@ -43,7 +43,7 @@ class SD3Tokenizer:
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

-    def tokenize_with_weights(self, text:str, return_word_ids=False):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

@@ -239,8 +239,11 @@ class T5(torch.nn.Module):
     def set_input_embeddings(self, embeddings):
         self.shared = embeddings

-    def forward(self, input_ids, *args, **kwargs):
+    def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
-        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+        if input_ids is None:
+            x = embeds
+        else:
+            x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
             x = torch.nan_to_num(x) #Fix for fp8 T5 base
-        return self.encoder(x, *args, **kwargs)
+        return self.encoder(x, attention_mask=attention_mask, **kwargs)

@@ -46,12 +46,18 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
 else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

-def load_torch_file(ckpt, safe_load=False, device=None):
+def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:
         device = torch.device("cpu")
+    metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
-            sd = safetensors.torch.load_file(ckpt, device=device.type)
+            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                sd = {}
+                for k in f.keys():
+                    sd[k] = f.get_tensor(k)
+                if return_metadata:
+                    metadata = f.metadata()
         except Exception as e:
             if len(e.args) > 0:
                 message = e.args[0]

@@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
             sd = pl_sd
         else:
             sd = pl_sd
-    return sd
+    return (sd, metadata) if return_metadata else sd

 def save_torch_file(sd, ckpt, metadata=None):
     if metadata is not None:

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import torchaudio
 import torch
 import comfy.model_management

@@ -10,6 +12,7 @@ import random
 import hashlib
 import node_helpers
 from comfy.cli_args import args
+from comfy.comfy_types import FileLocator

 class EmptyLatentAudio:
     def __init__(self):

@@ -164,7 +167,7 @@ class SaveAudio:
     def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
         filename_prefix += self.prefix_append
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        results = list()
+        results: list[FileLocator] = []

         metadata = {}
         if not args.disable_metadata:

|
|||||||
return {"required":
|
return {"required":
|
||||||
{"model": ("MODEL",),
|
{"model": ("MODEL",),
|
||||||
"add_noise": ("BOOLEAN", {"default": True}),
|
"add_noise": ("BOOLEAN", {"default": True}),
|
||||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
|
||||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||||
"positive": ("CONDITIONING", ),
|
"positive": ("CONDITIONING", ),
|
||||||
"negative": ("CONDITIONING", ),
|
"negative": ("CONDITIONING", ),
|
||||||
@@ -605,10 +605,16 @@ class DisableNoise:
 class RandomNoise(DisableNoise):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required":{
+        return {
-             "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            "required": {
-             }
+                "noise_seed": ("INT", {
-        }
+                    "default": 0,
+                    "min": 0,
+                    "max": 0xffffffffffffffff,
+                    "control_after_generate": True,
+                }),
+            }
+        }

     def get_noise(self, noise_seed):
         return (Noise_RandomNoise(noise_seed),)

@@ -1,4 +1,5 @@
 import nodes
+import node_helpers
 import torch
 import comfy.model_management

@@ -38,7 +39,83 @@ class EmptyHunyuanLatentVideo:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         return ({"samples":latent}, )

+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
+    "1. The main content and theme of the video."
+    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+    "4. background environment, light, style and atmosphere."
+    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+class TextEncodeHunyuanVideo_ImageToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, clip_vision_output, prompt, image_interleave):
+        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
+        return (clip.encode_from_tokens_scheduled(tokens), )
+
+class HunyuanImageToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "guidance_type": (["v1 (concat)", "v2 (replace)"], )
+                             },
+                "optional": {"start_image": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        out_latent = {}
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+            concat_latent_image = vae.encode(start_image)
+            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            else:
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+
+            positive = node_helpers.conditioning_set_values(positive, cond)
+
+        out_latent["samples"] = latent
+        return (positive, out_latent)
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
+    "HunyuanImageToVideo": HunyuanImageToVideo,
 }

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import nodes
 import folder_paths
 from comfy.cli_args import args

@@ -9,6 +11,8 @@ import numpy as np
 import json
 import os

+from comfy.comfy_types import FileLocator
+
 MAX_RESOLUTION = nodes.MAX_RESOLUTION

 class ImageCrop:

@@ -99,7 +103,7 @@ class SaveAnimatedWEBP:
         method = self.methods.get(method)
         filename_prefix += self.prefix_append
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
-        results = list()
+        results: list[FileLocator] = []
         pil_images = []
         for image in images:
             i = 255. * image.cpu().numpy()

@@ -19,8 +19,6 @@ class Load3D():
             "image": ("LOAD_3D", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")

@@ -55,8 +53,6 @@ class Load3DAnimation():
             "image": ("LOAD_3D_ANIMATION", {}),
             "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     RETURN_TYPES = ("IMAGE", "MASK", "STRING")

@@ -82,8 +78,6 @@ class Preview3D():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     OUTPUT_NODE = True

@@ -102,8 +96,6 @@ class Preview3DAnimation():
     def INPUT_TYPES(s):
         return {"required": {
             "model_file": ("STRING", {"default": "", "multiline": False}),
-            "material": (["original", "normal", "wireframe", "depth"],),
-            "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
         }}

     OUTPUT_NODE = True

|||||||
@@ -1,9 +1,14 @@
|
|||||||
|
import io
|
||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import av
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
|
||||||
class EmptyLTXVLatentVideo:
|
class EmptyLTXVLatentVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -33,7 +38,6 @@ class LTXVImgToVideo:
             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-            "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
         }}

     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")

@@ -42,16 +46,219 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"

-    def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
-        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
-        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})

         latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
         latent[:, :, :t.shape[2]] = t
-        return (positive, negative, {"samples": latent}, )
+        conditioning_latent_frames_mask = torch.ones(
+            (batch_size, 1, latent.shape[2], 1, 1),
+            dtype=torch.float32,
+            device=latent.device,
+        )
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
+
+        return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
+
+
+def conditioning_get_any_value(conditioning, key, default=None):
+    for t in conditioning:
+        if key in t[1]:
+            return t[1][key]
+    return default
+
+
+def get_noise_mask(latent):
+    noise_mask = latent.get("noise_mask", None)
+    latent_image = latent["samples"]
+    if noise_mask is None:
+        batch_size, _, latent_length, _, _ = latent_image.shape
+        noise_mask = torch.ones(
+            (batch_size, 1, latent_length, 1, 1),
+            dtype=torch.float32,
+            device=latent_image.device,
+        )
+    else:
+        noise_mask = noise_mask.clone()
+    return noise_mask
+
+def get_keyframe_idxs(cond):
+    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
+    if keyframe_idxs is None:
+        return None, 0
+    num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
+    return keyframe_idxs, num_keyframes
+
+class LTXVAddGuide:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE",),
+                             "latent": ("LATENT",),
+                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
+                                                 "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
+                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
+                                                   "tooltip": "Frame index to start the conditioning at. For single-frame images or "
+                                                   "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
+                                                   "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
+                                                   "the nearest multiple of 8. Negative values are counted from the end of the video."}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                             }
+                }
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    CATEGORY = "conditioning/video_models"
+    FUNCTION = "generate"
+
+    def __init__(self):
+        self._num_prefix_frames = 2
+        self._patchifier = SymmetricPatchifier(1)
+
+    def encode(self, vae, latent_width, latent_height, images, scale_factors):
+        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
+        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
+        encode_pixels = pixels[:, :, :, :3]
+        t = vae.encode(encode_pixels)
+        return encode_pixels, t
+
+    def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
+        time_scale_factor, _, _ = scale_factors
+        _, num_keyframes = get_keyframe_idxs(cond)
+        latent_count = latent_length - num_keyframes
+        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
+        if guide_length > 1:
+            frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
+
+        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
+
+        return frame_idx, latent_idx
+
+    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
+        keyframe_idxs, _ = get_keyframe_idxs(cond)
+        _, latent_coords = self._patchifier.patchify(guiding_latent)
+        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
+        pixel_coords[:, 0] += frame_idx
+        if keyframe_idxs is None:
+            keyframe_idxs = pixel_coords
+        else:
+            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
+        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
+
+    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
+        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
+        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
+        noise_mask = torch.cat([noise_mask, mask], dim=2)
+        return positive, negative, latent_image, noise_mask
+
+    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
+        cond_length = guiding_latent.shape[2]
+        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, cond_length, 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = latent_image.clone()
+        noise_mask = noise_mask.clone()
+
+        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
+        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
+
+        return latent_image, noise_mask
+
+    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
+        scale_factors = vae.downscale_index_formula
+        latent_image = latent["samples"]
+        noise_mask = get_noise_mask(latent)
+
+        _, _, latent_length, latent_height, latent_width = latent_image.shape
+        image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
+
+        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
+        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
+
+        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
+
+        positive, negative, latent_image, noise_mask = self.append_keyframe(
+            positive,
+            negative,
+            frame_idx,
+            latent_image,
+            noise_mask,
+            t[:, :, :num_prefix_frames],
+            strength,
+            scale_factors,
+        )
+
+        latent_idx += num_prefix_frames
+
+        t = t[:, :, num_prefix_frames:]
+        if t.shape[2] == 0:
+            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+        latent_image, noise_mask = self.replace_latent_frames(
+            latent_image,
+            noise_mask,
+            t,
+            latent_idx,
+            strength,
+        )
+
+        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+
+class LTXVCropGuides:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "latent": ("LATENT",),
+                             }
+                }
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    CATEGORY = "conditioning/video_models"
+    FUNCTION = "crop"
+
+    def __init__(self):
+        self._patchifier = SymmetricPatchifier(1)
+
+    def crop(self, positive, negative, latent):
+        latent_image = latent["samples"].clone()
+        noise_mask = get_noise_mask(latent)
+
+        _, num_keyframes = get_keyframe_idxs(positive)
+        if num_keyframes == 0:
+            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
+
+        latent_image = latent_image[:, :, :-num_keyframes]
+        noise_mask = noise_mask[:, :, :-num_keyframes]
+
+        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
+        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
+
+        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)


 class LTXVConditioning:

@@ -174,6 +381,77 @@ class LTXVScheduler:

         return (sigmas,)

+def encode_single_frame(output_file, image_array: np.ndarray, crf):
+    container = av.open(output_file, "w", format="mp4")
+    try:
+        stream = container.add_stream(
+            "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
+        )
+        stream.height = image_array.shape[0]
+        stream.width = image_array.shape[1]
+        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
+            format="yuv420p"
+        )
+        container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+    finally:
+        container.close()
+
+
+def decode_single_frame(video_file):
+    container = av.open(video_file)
+    try:
+        stream = next(s for s in container.streams if s.type == "video")
+        frame = next(container.decode(stream))
+    finally:
+        container.close()
+    return frame.to_ndarray(format="rgb24")
+
+
+def preprocess(image: torch.Tensor, crf=29):
+    if crf == 0:
+        return image
+
+    image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
+    with io.BytesIO() as output_file:
+        encode_single_frame(output_file, image_array, crf)
+        video_bytes = output_file.getvalue()
+    with io.BytesIO(video_bytes) as video_file:
+        image_array = decode_single_frame(video_file)
+    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+    return tensor
+
+
+class LTXVPreprocess:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "img_compression": (
+                    "INT",
+                    {
+                        "default": 35,
+                        "min": 0,
+                        "max": 100,
+                        "tooltip": "Amount of compression to apply on image.",
+                    },
+                ),
+            }
+        }
+
+    FUNCTION = "preprocess"
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("output_image",)
+    CATEGORY = "image"
+
+    def preprocess(self, image, img_compression):
+        if img_compression > 0:
+            output_images = []
+            for i in range(image.shape[0]):
+                output_images.append(preprocess(image[i], img_compression))
+        return (torch.stack(output_images),)
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,

@@ -181,4 +459,7 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingLTXV": ModelSamplingLTXV,
     "LTXVConditioning": LTXVConditioning,
     "LTXVScheduler": LTXVScheduler,
+    "LTXVAddGuide": LTXVAddGuide,
+    "LTXVPreprocess": LTXVPreprocess,
+    "LTXVCropGuides": LTXVCropGuides,
 }

@@ -1,9 +1,12 @@
+from __future__ import annotations
+
 import os
 import av
 import torch
 import folder_paths
 import json
 from fractions import Fraction
+from comfy.comfy_types import FileLocator


 class SaveWEBM:

@@ -25,15 +28,12 @@ class SaveWEBM:
         }

     RETURN_TYPES = ()
-    FUNCTION = "save_images"
+    FUNCTION = "save_video"

     OUTPUT_NODE = True
+    CATEGORY = "video"
-    CATEGORY = "image/video"

     EXPERIMENTAL = True

-    def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
+    def save_video(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
         filename_prefix += self.prefix_append
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])

@@ -62,13 +62,13 @@ class SaveWEBM:
|
|||||||
container.mux(stream.encode())
|
container.mux(stream.encode())
|
||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
results = [{
|
results: list[FileLocator] = [{
|
||||||
"filename": file,
|
"filename": file,
|
||||||
"subfolder": subfolder,
|
"subfolder": subfolder,
|
||||||
"type": self.type
|
"type": self.type
|
||||||
}]
|
}]
|
||||||
|
|
||||||
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
|
return {"ui": {"video": results}}
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
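The FileLocator annotation added here only names the shape of the entries SaveWEBM was already returning, and the result now goes out under a ui["video"] key instead of ui["images"]. The actual type lives in comfy.comfy_types and is not shown in this diff, so the following is a sketch inferred from the fields used above (the class name and the Literal values are assumptions):

from typing import Literal, TypedDict

# Sketch only: field names taken from the results dict in SaveWEBM above; the real
# definition in comfy.comfy_types may differ or carry extra metadata.
class FileLocatorSketch(TypedDict):
    filename: str                              # e.g. "ComfyUI_00001_.webm"
    subfolder: str                             # path below the output root, may be ""
    type: Literal["output", "input", "temp"]   # which managed folder the file was written to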
@@ -4,6 +4,7 @@ import comfy.utils
 import comfy.sd
 import folder_paths
 import comfy_extras.nodes_model_merging
+import node_helpers
 
 
 class ImageOnlyCheckpointLoader:
@@ -121,12 +122,38 @@ class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
         comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
         return {}
 
+
+class ConditioningSetAreaPercentageVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "conditioning"
+
+    def append(self, conditioning, width, height, temporal, x, y, z, strength):
+        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
+        return (c, )
+
+
 NODE_CLASS_MAPPINGS = {
     "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
     "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
     "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
     "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
     "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
+    "ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
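ConditioningSetAreaPercentageVideo reads like a video-aware variant of the existing area-percentage conditioning node: every placement value is a fraction of the latent's size, with temporal and z covering the time axis, and append() packs them as ("percentage", temporal, height, width, z, y, x) onto each conditioning entry through node_helpers.conditioning_set_values. A hedged example of calling the node directly inside a ComfyUI environment (the single-entry conditioning list is a stand-in; it would normally come from a text-encode node):

# Stand-in conditioning: normally [(cond_tensor, options_dict), ...] produced upstream.
dummy_conditioning = [(None, {})]

node = ConditioningSetAreaPercentageVideo()
(patched,) = node.append(
    dummy_conditioning,
    width=0.5, height=0.5, temporal=1.0,   # cover half the frame for the whole clip
    x=0.25, y=0.25, z=0.0,                 # offsets, also as fractions
    strength=1.0,
)
print(patched[0][1]["area"])   # ("percentage", 1.0, 0.5, 0.5, 0.0, 0.25, 0.25)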
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.18"
+__version__ = "0.3.26"
@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
                     continue
                 else:
                     try:
+                        # Unwraps values wrapped in __value__ key. This is used to pass
+                        # list widget value to execution, as by default list value is
+                        # reserved to represent the connection between nodes.
+                        if isinstance(val, dict) and "__value__" in val:
+                            val = val["__value__"]
+                            inputs[x] = val
+
                         if type_input == "INT":
                             val = int(val)
                             inputs[x] = val
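The unwrapping added to validate_inputs is small but worth spelling out: in the prompt JSON a plain list is reserved for a link between nodes, so a literal list coming from a list widget is wrapped in a {"__value__": ...} dict, and validation now unpacks it before the usual type coercions run. A minimal illustration of the convention (the sampler names are just example widget content; the ["4", 0] link shape follows the usual node-id/output-index form):

# A link is itself a list: [source_node_id, output_index].
link_value = ["4", 0]

# A literal list chosen in a list widget arrives wrapped so it cannot be mistaken for a link.
widget_value = {"__value__": ["euler", "euler_ancestral"]}

val = widget_value
if isinstance(val, dict) and "__value__" in val:
    val = val["__value__"]   # -> ["euler", "euler_ancestral"], the same unwrapping as above

print(val, link_value)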
6 main.py
@@ -139,6 +139,7 @@ from server import BinaryEventTypes
 import nodes
 import comfy.model_management
 import comfyui_version
+import app.logger
 
 
 def cuda_malloc_warning():
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
 if __name__ == "__main__":
     # Running directly, just start ComfyUI.
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
 
     event_loop, _, start_all_func = start_comfyui()
     try:
-        event_loop.run_until_complete(start_all_func())
+        x = start_all_func()
+        app.logger.print_startup_warnings()
+        event_loop.run_until_complete(x)
     except KeyboardInterrupt:
         logging.info("\nStopped server")
20 nodes.py
@@ -25,7 +25,7 @@ import comfy.sample
 import comfy.sd
 import comfy.utils
 import comfy.controlnet
-from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
 
 import comfy.clip_vision
@@ -479,7 +479,7 @@ class SaveLatent:
 
         file = f"{filename}_{counter:05}_.latent"
 
-        results = list()
+        results: list[FileLocator] = []
         results.append({
             "filename": file,
             "subfolder": subfolder,
@@ -489,7 +489,7 @@ class SaveLatent:
         file = os.path.join(full_output_folder, file)
 
         output = {}
-        output["latent_tensor"] = samples["samples"]
+        output["latent_tensor"] = samples["samples"].contiguous()
         output["latent_format_version_0"] = torch.tensor([])
 
         comfy.utils.save_torch_file(output, file, metadata=metadata)
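The .contiguous() call on the saved latent is about serialization rather than math: a tensor that is a strided view (the result of slicing or permuting) shares storage with its parent, and a flat-buffer serializer needs a densely laid-out copy; save_torch_file in ComfyUI writes safetensors, which is presumably why the copy is forced here (that link is an inference, not shown in this diff). A torch-only sketch of the distinction:

import torch

latent = torch.randn(1, 4, 64, 64)
view = latent.permute(0, 1, 3, 2)   # a strided view over the same storage
print(view.is_contiguous())         # False: not safe to hand to a flat-buffer serializer

packed = view.contiguous()          # materializes an independent, densely laid-out copy
print(packed.is_contiguous())       # True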
@@ -770,6 +770,7 @@ class VAELoader:
         vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
         sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
+        vae.throw_exception_if_invalid()
         return (vae,)
 
 class ControlNetLoader:
@@ -1519,7 +1520,7 @@ class KSampler:
         return {
             "required": {
                 "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
                 "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
                 "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
                 "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
@@ -1547,7 +1548,7 @@ class KSamplerAdvanced:
         return {"required":
                     {"model": ("MODEL",),
                     "add_noise": (["enable", "disable"], ),
-                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
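control_after_generate does not change sampling itself; it is a widget hint the frontend reads so the seed and noise_seed fields get their fixed/increment/randomize control back after each run. Any custom node can declare the same hint on an INT input (sketch only; the class name is illustrative and the exact frontend behavior is an assumption based on how KSampler uses it above):

class SeededNodeSketch:
    @classmethod
    def INPUT_TYPES(s):
        # Same widget options dict as the KSampler seed input in the hunk above.
        return {"required": {
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
        }}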
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
 
     DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
     EXPERIMENTAL = True
-    FUNCTION = "load_image_output"
-
-    def load_image_output(self, image):
-        return self.load_image(f"{image} [output]")
-
-    @classmethod
-    def VALIDATE_INPUTS(s, image):
-        return True
+    FUNCTION = "load_image"
 
 
 class ImageScale:
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.18"
+version = "0.3.26"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.10.17
+comfyui-frontend-package==1.12.14
 torch
 torchsde
 torchvision
@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
 def test_init_frontend_default():
     version_string = DEFAULT_VERSION_STRING
     frontend_path = FrontendManager.init_frontend(version_string)
-    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+    assert frontend_path == FrontendManager.default_frontend_path()
 
 
 def test_init_frontend_invalid_version():
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)
 
 
 @pytest.fixture
 def mock_os_functions():
-    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
-         patch('app.frontend_management.os.listdir') as mock_listdir, \
-         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+    with (
+        patch("app.frontend_management.os.makedirs") as mock_makedirs,
+        patch("app.frontend_management.os.listdir") as mock_listdir,
+        patch("app.frontend_management.os.rmdir") as mock_rmdir,
+    ):
         mock_listdir.return_value = [] # Simulate empty directory
         yield mock_makedirs, mock_listdir, mock_rmdir
 
 
 @pytest.fixture
 def mock_download():
-    with patch('app.frontend_management.download_release_asset_zip') as mock:
+    with patch("app.frontend_management.download_release_asset_zip") as mock:
         mock.side_effect = Exception("Download failed") # Simulate download failure
         yield mock
 
 
 def test_finally_block(mock_os_functions, mock_download, mock_provider):
     # Arrange
     mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
-    version_string = 'test-owner/test-repo@1.0.0'
+    version_string = "test-owner/test-repo@1.0.0"
 
     # Act & Assert
     with pytest.raises(Exception):
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
     version_string = "invalid"
     with pytest.raises(argparse.ArgumentTypeError):
         FrontendManager.parse_version_string(version_string)
+
+
+def test_init_frontend_default_with_mocks():
+    # Arrange
+    version_string = DEFAULT_VERSION_STRING
+
+    # Act
+    with (
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/mocked/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/mocked/path"
+    mock_check.assert_called_once()
+
+
+def test_init_frontend_fallback_on_error():
+    # Arrange
+    version_string = "test-owner/test-repo@1.0.0"
+
+    # Act
+    with (
+        patch.object(
+            FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
+        ),
+        patch("app.frontend_management.check_frontend_version") as mock_check,
+        patch.object(
+            FrontendManager, "default_frontend_path", return_value="/default/path"
+        ),
+    ):
+        frontend_path = FrontendManager.init_frontend(version_string)
+
+    # Assert
+    assert frontend_path == "/default/path"
+    mock_check.assert_called_once()