Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eaba79602f | ||
|
|
35504e2f93 | ||
|
|
299436cfed | ||
|
|
52e566d2bc | ||
|
|
9b6cd9b874 | ||
|
|
3fc688aebd | ||
|
|
f4411250f3 | ||
|
|
d2a0fb6bb0 | ||
|
|
01015bff16 | ||
|
|
2330754b0e | ||
|
|
bc219a6487 | ||
|
|
94689766ad | ||
|
|
cfbe4b49ca | ||
|
|
ca8efab79f | ||
|
|
65ea778a5e | ||
|
|
db9f2a34fc | ||
|
|
7946049794 | ||
|
|
6f6349b6a7 | ||
|
|
1f138dd382 | ||
|
|
b779349b55 | ||
|
|
35e2dcf5d7 | ||
|
|
67c7184b74 | ||
|
|
6f8e766509 | ||
|
|
e1da98a14a | ||
|
|
a73410aafa | ||
|
|
9aac21f894 | ||
|
|
528d1b3563 | ||
|
|
2bc4b5968f | ||
|
|
7395b0c0d1 | ||
|
|
0952569493 | ||
|
|
29832b3b61 | ||
|
|
be4e760648 | ||
|
|
c3d9cc4592 | ||
|
|
84cc9cb528 | ||
|
|
ebbb920163 | ||
|
|
d60fe0af4a | ||
|
|
5dbd250965 | ||
|
|
4ab1875283 | ||
|
|
11b1f27cb1 | ||
|
|
70e15fd743 | ||
|
|
e1474150de | ||
|
|
e62d72e8ca | ||
|
|
1650cda030 |
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
pause
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "1"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 30
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||
|
||||
cd ComfyUI_windows_portable_nightly_pytorch
|
||||
|
||||
@@ -19,5 +19,6 @@
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
|
||||
|
||||
# Extra nodes
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
|
||||
|
||||
@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@@ -11,20 +11,43 @@ from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
|
||||
|
||||
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
except ImportError:
|
||||
# TODO: Remove the check after roll out of 0.3.16
|
||||
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
|
||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
|
||||
exit(-1)
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
|
||||
else:
|
||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check frontend version: {e}")
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
@@ -121,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
except ImportError:
|
||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
@@ -160,7 +191,8 @@ class FrontendManager:
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
|
||||
@@ -213,4 +245,5 @@ class FrontendManager:
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
STARTUP_WARNINGS = []
|
||||
|
||||
|
||||
def log_startup_warning(msg):
|
||||
logging.warning(msg)
|
||||
STARTUP_WARNINGS.append(msg)
|
||||
|
||||
|
||||
def print_startup_warnings():
|
||||
for s in STARTUP_WARNINGS:
|
||||
logging.warning(s)
|
||||
STARTUP_WARNINGS.clear()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
@@ -26,6 +27,7 @@ class IO(StrEnum):
|
||||
BOOLEAN = "BOOLEAN"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
COMBO = "COMBO"
|
||||
CONDITIONING = "CONDITIONING"
|
||||
SAMPLER = "SAMPLER"
|
||||
SIGMAS = "SIGMAS"
|
||||
@@ -66,6 +68,7 @@ class IO(StrEnum):
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
|
||||
class RemoteInputOptions(TypedDict):
|
||||
route: str
|
||||
"""The route to the remote source."""
|
||||
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
|
||||
refresh: int
|
||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
||||
|
||||
|
||||
class MultiSelectOptions(TypedDict):
|
||||
placeholder: NotRequired[str]
|
||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
||||
chip: NotRequired[bool]
|
||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
||||
|
||||
|
||||
class InputTypeOptions(TypedDict):
|
||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
||||
|
||||
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
|
||||
# default: bool
|
||||
label_on: str
|
||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
||||
label_on: str
|
||||
label_off: str
|
||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
||||
# class InputTypeString(InputTypeOptions):
|
||||
# default: str
|
||||
@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
|
||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
||||
"""
|
||||
remote: RemoteInputOptions
|
||||
"""Specifies the configuration for a remote input."""
|
||||
"""Specifies the configuration for a remote input.
|
||||
Available after ComfyUI frontend v1.9.7
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
||||
control_after_generate: bool
|
||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
||||
options: NotRequired[list[str | int | float]]
|
||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
||||
Prefer:
|
||||
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
||||
Over:
|
||||
[["Option 1", "Option 2", "Option 3"]]
|
||||
"""
|
||||
multi_select: NotRequired[MultiSelectOptions]
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
ops.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
ops.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
@@ -171,16 +175,16 @@ class StageA(nn.Module):
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
@@ -191,7 +195,7 @@ class StageA(nn.Module):
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
@@ -199,11 +203,11 @@ class StageA(nn.Module):
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
|
||||
@@ -19,6 +19,9 @@ import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
return tensor
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
vec = vec[:, None, :]
|
||||
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
|
||||
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is not None:
|
||||
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = img[:, : img_len]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
return out
|
||||
|
||||
@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
@@ -898,20 +898,31 @@ class HunyuanVideo(BaseModel):
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
guiding_frame_index = kwargs.get("guiding_frame_index", None)
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image", "mask_inverted")
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image",)
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
|
||||
|
||||
class CosmosVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
|
||||
@@ -962,11 +973,11 @@ class WAN21(BaseModel):
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
if not self.image_to_video:
|
||||
noise = kwargs.get("noise", None)
|
||||
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
@@ -976,6 +987,9 @@ class WAN21(BaseModel):
|
||||
image = self.process_latent_in(image)
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if not self.image_to_video:
|
||||
return image
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :4]
|
||||
|
||||
@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@@ -186,12 +186,21 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
mac_ver = mac_version()
|
||||
if mac_ver is not None:
|
||||
logging.info("Mac Version {}".format(mac_ver))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -581,7 +590,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
@@ -969,12 +978,6 @@ def pytorch_attention_flash_attention():
|
||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
return False
|
||||
|
||||
def mac_version():
|
||||
try:
|
||||
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
except:
|
||||
return None
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
|
||||
|
||||
@@ -1089,7 +1089,6 @@ class ModelPatcher:
|
||||
|
||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
if hooks is not None:
|
||||
model_sd_keys = list(self.model_state_dict().keys())
|
||||
memory_counter = None
|
||||
@@ -1100,12 +1099,16 @@ class ModelPatcher:
|
||||
# if have cached weights for hooks, use it
|
||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||
if cached_weights is not None:
|
||||
model_sd_keys_set = set(model_sd_keys)
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
model_sd_keys_set.remove(key)
|
||||
self.unpatch_hooks(model_sd_keys_set)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||
original_weights = None
|
||||
if len(relevant_patches) > 0:
|
||||
@@ -1116,6 +1119,8 @@ class ModelPatcher:
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
memory_counter=memory_counter)
|
||||
else:
|
||||
self.unpatch_hooks()
|
||||
self.current_hooks = hooks
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
@@ -1172,17 +1177,23 @@ class ModelPatcher:
|
||||
del out_weight
|
||||
del weight
|
||||
|
||||
def unpatch_hooks(self) -> None:
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
self.current_hooks = None
|
||||
return
|
||||
keys = list(self.hook_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
if whitelist_keys_set:
|
||||
for k in keys:
|
||||
if k in whitelist_keys_set:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
self.hook_backup.pop(k)
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
|
||||
def clean_hooks(self):
|
||||
self.unpatch_hooks()
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation"]
|
||||
"gradient_estimation", "er_sde"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
||||
@@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -42,7 +42,7 @@ class HunyuanVideoTokenizer:
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
@@ -56,7 +56,7 @@ class HunyuanVideoTokenizer:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 128257:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
|
||||
embed_count += 1
|
||||
out["llama"] = llama_text_tokens
|
||||
return out
|
||||
@@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
|
||||
template_end = 0
|
||||
image_start = None
|
||||
image_end = None
|
||||
extra_template_end = 0
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
images = []
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
@@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if image_start is None:
|
||||
if template_end > 0:
|
||||
if user_end == -1:
|
||||
extra_template_end += elem_size - 1
|
||||
else:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
extra_sizes += elem_size - 1
|
||||
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if image_start is not None:
|
||||
image_output = llama_out[:, image_start: image_end]
|
||||
llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
|
||||
if len(images) > 0:
|
||||
out = []
|
||||
for i in images:
|
||||
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
|
||||
@@ -57,17 +57,17 @@ class TextEncodeHunyuanVideo_ImageToVideo:
|
||||
"clip": ("CLIP", ),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, clip_vision_output, prompt):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
|
||||
def encode(self, clip, clip_vision_output, prompt, image_interleave):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class HunyuanImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -77,6 +77,7 @@ class HunyuanImageToVideo:
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
@@ -87,8 +88,10 @@ class HunyuanImageToVideo:
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
|
||||
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
@@ -96,13 +99,20 @@ class HunyuanImageToVideo:
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, out_latent)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||
|
||||
@@ -19,8 +19,6 @@ class Load3D():
|
||||
"image": ("LOAD_3D", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -55,8 +53,6 @@ class Load3DAnimation():
|
||||
"image": ("LOAD_3D_ANIMATION", {}),
|
||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
||||
@@ -82,8 +78,6 @@ class Preview3D():
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
@@ -102,8 +96,6 @@ class Preview3DAnimation():
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||
"material": (["original", "normal", "wireframe", "depth"],),
|
||||
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
|
||||
}}
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@@ -99,12 +99,13 @@ class LTXVAddGuide:
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE",),
|
||||
"latent": ("LATENT",),
|
||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
|
||||
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
|
||||
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
|
||||
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
|
||||
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
|
||||
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
|
||||
"Negative values are counted from the end of the video."}),
|
||||
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
|
||||
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
|
||||
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
|
||||
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
@@ -127,12 +128,13 @@ class LTXVAddGuide:
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
|
||||
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1:
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
|
||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
@@ -191,7 +193,7 @@ class LTXVAddGuide:
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
|
||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
|
||||
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.24"
|
||||
__version__ = "0.3.26"
|
||||
|
||||
@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||
# list widget value to execution, as by default list value is
|
||||
# reserved to represent the connection between nodes.
|
||||
if isinstance(val, dict) and "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
|
||||
6
main.py
6
main.py
@@ -139,6 +139,7 @@ from server import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
|
||||
|
||||
def cuda_malloc_warning():
|
||||
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
|
||||
if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
event_loop.run_until_complete(start_all_func())
|
||||
x = start_all_func()
|
||||
app.logger.print_startup_warnings()
|
||||
event_loop.run_until_complete(x)
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
||||
|
||||
11
nodes.py
11
nodes.py
@@ -489,7 +489,7 @@ class SaveLatent:
|
||||
file = os.path.join(full_output_folder, file)
|
||||
|
||||
output = {}
|
||||
output["latent_tensor"] = samples["samples"]
|
||||
output["latent_tensor"] = samples["samples"].contiguous()
|
||||
output["latent_format_version_0"] = torch.tensor([])
|
||||
|
||||
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
||||
@@ -1785,14 +1785,7 @@ class LoadImageOutput(LoadImage):
|
||||
|
||||
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
|
||||
EXPERIMENTAL = True
|
||||
FUNCTION = "load_image_output"
|
||||
|
||||
def load_image_output(self, image):
|
||||
return self.load_image(f"{image} [output]")
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
return True
|
||||
FUNCTION = "load_image"
|
||||
|
||||
|
||||
class ImageScale:
|
||||
|
||||
130
pyproject.toml
130
pyproject.toml
@@ -1,9 +1,79 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.24"
|
||||
version = "0.3.26"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"torchsde",
|
||||
"numpy>=1.25.0",
|
||||
"einops",
|
||||
"transformers>=4.28.1",
|
||||
"tokenizers>=0.13.3",
|
||||
"sentencepiece",
|
||||
"safetensors>=0.4.2",
|
||||
"aiohttp",
|
||||
"pyyaml",
|
||||
"Pillow",
|
||||
"scipy",
|
||||
"tqdm",
|
||||
"psutil",
|
||||
# Optional dependencies
|
||||
"kornia>=0.7.1",
|
||||
"spandrel",
|
||||
"soundfile",
|
||||
"comfyui-manager",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
cu126 = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
cu124 = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
cu118 = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
rocm = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
xpus = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "cu126" },
|
||||
{ extra = "cu124" },
|
||||
{ extra = "cu118" },
|
||||
{ extra = "rocm" },
|
||||
{ extra = "xpus" },
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://www.comfy.org/"
|
||||
@@ -21,3 +91,61 @@ lint.select = [
|
||||
"F",
|
||||
]
|
||||
exclude = ["*.ipynb"]
|
||||
|
||||
|
||||
[tool.uv.sources]
|
||||
comfyui-manager = { path = "custom_nodes/ComfyUI-Manager" }
|
||||
torch = [
|
||||
{ index = "pytorch-cu126", extra = "cu126" },
|
||||
{ index = "pytorch-cu124", extra = "cu124" },
|
||||
{ index = "pytorch-cu118", extra = "cu118" },
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-rocm", extra = "rocm" },
|
||||
{ index = "pytorch-xpu", extra = "xpus" },
|
||||
]
|
||||
torchvision = [
|
||||
{ index = "pytorch-cu126", extra = "cu126" },
|
||||
{ index = "pytorch-cu124", extra = "cu124" },
|
||||
{ index = "pytorch-cu118", extra = "cu118" },
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-rocm", extra = "rocm" },
|
||||
{ index = "pytorch-xpu", extra = "xpus" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cu126", extra = "cu126" },
|
||||
{ index = "pytorch-cu124", extra = "cu124" },
|
||||
{ index = "pytorch-cu118", extra = "cu118" },
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-rocm", extra = "rocm" },
|
||||
{ index = "pytorch-xpu", extra = "xpus" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu126"
|
||||
url = "https://download.pytorch.org/whl/cu126"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu124"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu118"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-rocm"
|
||||
url = "https://download.pytorch.org/whl/rocm6.2"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-xpu"
|
||||
url = "https://download.pytorch.org/whl/xpu"
|
||||
explicit = true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.10.17
|
||||
comfyui-frontend-package==1.11.8
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
|
||||
def test_init_frontend_default():
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
||||
assert frontend_path == FrontendManager.default_frontend_path()
|
||||
|
||||
|
||||
def test_init_frontend_invalid_version():
|
||||
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_os_functions():
|
||||
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
||||
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
||||
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
||||
with (
|
||||
patch("app.frontend_management.os.makedirs") as mock_makedirs,
|
||||
patch("app.frontend_management.os.listdir") as mock_listdir,
|
||||
patch("app.frontend_management.os.rmdir") as mock_rmdir,
|
||||
):
|
||||
mock_listdir.return_value = [] # Simulate empty directory
|
||||
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_download():
|
||||
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
||||
with patch("app.frontend_management.download_release_asset_zip") as mock:
|
||||
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||
yield mock
|
||||
|
||||
|
||||
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||
# Arrange
|
||||
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||
version_string = 'test-owner/test-repo@1.0.0'
|
||||
version_string = "test-owner/test-repo@1.0.0"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception):
|
||||
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
|
||||
version_string = "invalid"
|
||||
with pytest.raises(argparse.ArgumentTypeError):
|
||||
FrontendManager.parse_version_string(version_string)
|
||||
|
||||
|
||||
def test_init_frontend_default_with_mocks():
|
||||
# Arrange
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/mocked/path"
|
||||
),
|
||||
):
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
|
||||
# Assert
|
||||
assert frontend_path == "/mocked/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
|
||||
def test_init_frontend_fallback_on_error():
|
||||
# Arrange
|
||||
version_string = "test-owner/test-repo@1.0.0"
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(
|
||||
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
|
||||
),
|
||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||
patch.object(
|
||||
FrontendManager, "default_frontend_path", return_value="/default/path"
|
||||
),
|
||||
):
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
|
||||
# Assert
|
||||
assert frontend_path == "/default/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user