Compare commits

..

19 Commits

Author SHA1 Message Date
Chenlei Hu
522d923948 nit 2025-03-25 16:47:52 -04:00
Chenlei Hu
c05c9b552b nit 2025-03-25 16:47:42 -04:00
Chenlei Hu
27598702e9 [Type] Annotate graph.get_input_info 2025-03-25 16:44:55 -04:00
comfyanonymous
8edc1f44c1 Support more float8 types. 2025-03-25 05:23:49 -04:00
comfyanonymous
eade1551bb Add Hunyuan3D to readme. 2025-03-24 07:14:32 -04:00
comfyanonymous
581a9991ff Add model merging node for WAN 2.1 2025-03-23 08:06:36 -04:00
comfyanonymous
e471c726e5 Fallback to pytorch attention if sage attention fails. 2025-03-22 15:45:56 -04:00
comfyanonymous
75c1c757d9 ComfyUI version v0.3.27 2025-03-21 20:09:54 -04:00
Chenlei Hu
ce9b084279 [nit] Format error strings (#7345) 2025-03-21 19:08:25 -04:00
Terry Jia
2206246055 support output normal and lineart once (#7290) 2025-03-21 16:24:13 -04:00
comfyanonymous
d9fa9d307f Automatically set the right sampling type for lotus. 2025-03-21 14:19:37 -04:00
thot experiment
83e839a89b Native LotusD Implementation (#7125)
* draft pass at a native comfy implementation of Lotus-D depth and normal est

* fix model_sampling kludges

* fix ruff

---------

Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
2025-03-21 14:04:15 -04:00
Chenlei Hu
0cf2274699 Update frontend to 1.14 (#7343) 2025-03-21 13:50:09 -04:00
comfyanonymous
0956107170 Nodes to convert images to YUV and back.
Can be used to convert an image to black and white.
2025-03-21 06:32:44 -04:00
Chenlei Hu
a4a956dbbd Add backend primitive nodes (#7328)
* Add backend primitive nodes

* Add control after generate to int primitive
2025-03-21 01:47:18 -04:00
Chenlei Hu
8b9ce4ed18 Update frontend to 1.13 (#7331) 2025-03-21 00:17:36 -04:00
comfyanonymous
3872b43d4b A few fixes for the hunyuan3d models. 2025-03-20 04:52:31 -04:00
comfyanonymous
32ca0805b7 Fix orientation of hunyuan 3d model. 2025-03-19 19:55:24 -04:00
comfyanonymous
11f1b41bab Initial Hunyuan3Dv2 implementation.
Supports the multiview, mini, turbo models and VAEs.
2025-03-19 16:52:58 -04:00
28 changed files with 1612 additions and 255 deletions

View File

@@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system - Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Many optimizations: Only re-executes the parts of the workflow that changes between executions.

View File

@@ -22,13 +22,21 @@ import app.logger
# The path to the requirements.txt file # The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt" req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message(): def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date.""" """The warning message to display when the frontend version is not up to date."""
extra = "" extra = ""
if sys.flags.no_user_site: if sys.flags.no_user_site:
extra = "-s " extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
def check_frontend_version(): def check_frontend_version():
@@ -43,7 +51,17 @@ def check_frontend_version():
with open(req_path, "r", encoding="utf-8") as f: with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1]) required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend: if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message())) app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
)
else: else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str)) logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e: except Exception as e:
@@ -150,9 +168,20 @@ class FrontendManager:
def default_frontend_path(cls) -> str: def default_frontend_path(cls) -> str:
try: try:
import comfyui_frontend_package import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static") return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError: except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n") logging.error(
f"""
********** ERROR ***********
comfyui-frontend-package is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
sys.exit(-1) sys.exit(-1)
@classmethod @classmethod
@@ -175,7 +204,9 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3) return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod @classmethod
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: def init_frontend_unsafe(
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
""" """
Initializes the frontend for the specified version. Initializes the frontend for the specified version.
@@ -197,12 +228,20 @@ class FrontendManager:
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"): if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) expected_path = str(
Path(cls.CUSTOM_FRONTENDS_ROOT)
/ f"{repo_owner}_{repo_name}"
/ version.lstrip("v")
)
if os.path.exists(expected_path): if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") logging.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") logging.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
provider = provider or FrontEndProvider(repo_owner, repo_name) provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version) release = provider.get_release(version)

View File

@@ -456,3 +456,13 @@ class Wan21(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean return latent * latents_std / self.scale_factor + latents_mean
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404

View File

@@ -0,0 +1,135 @@
import torch
from torch import nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
class Hunyuan3Dv2(nn.Module):
def __init__(
self,
in_channels=64,
context_in_dim=1536,
hidden_size=1024,
mlp_ratio=4.0,
num_heads=16,
depth=16,
depth_single_blocks=32,
qkv_bias=True,
guidance_embed=False,
image_model=None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dtype = dtype
if hidden_size % num_heads != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
)
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
)
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
img = self.latent_in(x)
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
if self.guidance_in is not None:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
txt = self.cond_in(txt)
pe = None
attn_mask = None
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
return img.movedim(-2, -1) * (-1.0)

587
comfy/ldm/hunyuan3d/vae.py Normal file
View File

@@ -0,0 +1,587 @@
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
[
sin(x[..., i]),
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i] # only present if include_input is True.
], here f_i is the frequency.
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
Args:
num_freqs (int): the number of frequencies, default is 6;
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
input_dim (int): the input dimension, default is 3;
include_input (bool): include the input tensor or not, default is True.
Attributes:
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
otherwise, it is input_dim * num_freqs * 2.
"""
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
"""The initialization"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = F.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
class MLP(nn.Module):
def __init__(
self, *,
width: int,
expand_ratio: int = 4,
output_width: int = None,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.c_fc = ops.Linear(width, width * expand_ratio)
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
data_width: Optional[int] = None,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
kv_cache: bool = False,
):
super().__init__()
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = ops.Linear(width, width, bias=qkv_bias)
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.kv_cache = kv_cache
self.data = None
def forward(self, x, data):
x = self.c_q(x)
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
logging.info('Save kv cache,this should be called only once for one mesh')
data = self.data
else:
data = self.c_kv(data)
x = self.attention(x, data)
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class QKVMultiheadAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
x = self.c_qkv(x)
x = self.attention(x)
x = self.drop_path(self.c_proj(x))
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
):
super().__init__()
self.attn = MultiheadAttention(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class CrossAttentionDecoder(nn.Module):
def __init__(
self,
*,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
downsample_ratio: int = 1,
enable_ln_post: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
):
super().__init__()
self.enable_ln_post = enable_ln_post
self.fourier_embedder = fourier_embedder
self.downsample_ratio = downsample_ratio
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if self.enable_ln_post == False:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
mlp_expand_ratio=mlp_expand_ratio,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if self.enable_ln_post:
self.ln_post = ops.LayerNorm(width)
self.output_proj = ops.Linear(width, out_channels)
self.label_type = label_type
self.count = 0
def forward(self, queries=None, query_embeddings=None, latents=None):
if query_embeddings is None:
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
self.count += query_embeddings.shape[1]
if self.downsample_ratio != 1:
latents = self.latents_proj(latents)
x = self.cross_attn_decoder(query_embeddings, latents)
if self.enable_ln_post:
x = self.ln_post(x)
occ = self.output_proj(x)
return occ
class ShapeVAE(nn.Module):
def __init__(
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
width=width,
layers=num_decoder_layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.geo_decoder = CrossAttentionDecoder(
fourier_embedder=self.fourier_embedder,
out_channels=1,
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
downsample_ratio=geo_decoder_downsample_ratio,
enable_ln_post=self.geo_decoder_ln_post,
width=width // geo_decoder_downsample_ratio,
heads=heads // geo_decoder_downsample_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
)
self.volume_decoder = VanillaVolumeDecoder()
self.scale_factor = scale_factor
def decode(self, latents, **kwargs):
latents = self.post_kl(latents.movedim(-2, -1))
latents = self.transformer(latents)
bounds = kwargs.get("bounds", 1.01)
num_chunks = kwargs.get("num_chunks", 8000)
octree_resolution = kwargs.get("octree_resolution", 256)
enable_pbar = kwargs.get("enable_pbar", True)
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, x):
return None

View File

@@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
tensor_layout="HND" tensor_layout = "HND"
else: else:
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
@@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head), lambda t: t.view(b, -1, heads, dim_head),
(q, k, v), (q, k, v),
) )
tensor_layout="NHD" tensor_layout = "NHD"
if mask is not None: if mask is not None:
# add a batch dimension if there isn't already one # add a batch dimension if there isn't already one
@@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND": if tensor_layout == "HND":
if not skip_output_reshape: if not skip_output_reshape:
out = ( out = (

View File

@@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import torch import torch
import logging import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -37,6 +36,7 @@ import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model import comfy.ldm.lumina.model
import comfy.ldm.wan.model import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@@ -59,6 +59,7 @@ class ModelType(Enum):
FLOW = 6 FLOW = 6
V_PREDICTION_CONTINUOUS = 7 V_PREDICTION_CONTINUOUS = 7
FLUX = 8 FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -89,6 +90,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLUX: elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@@ -105,7 +108,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device self.device = device
self.current_patcher: ModelPatcher = None self.current_patcher: 'ModelPatcher' = None
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
@@ -129,7 +132,6 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name)) logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels)) logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor = model_config.memory_usage_factor
self.zipper_initialized = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor( return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -139,18 +141,9 @@ class BaseModel(torch.nn.Module):
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# handle lowvram zipper initialization, if required
if self.model_lowvram and not self.zipper_initialized:
if self.current_patcher:
self.zipper_initialized = True
with self.current_patcher.use_ejected():
loading = self.current_patcher._load_list_lowvram_only()
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t
xc = self.model_sampling.calculate_input(sigma, x) xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None: if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1) xc = torch.cat([xc] + [c_concat], dim=1)
@@ -612,6 +605,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
else: else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p self.process_ip2p_image_in = lambda image: image #diffusers ip2p
class Lotus(BaseModel):
def extra_conds(self, **kwargs):
out = {}
cross_attn = kwargs.get("cross_attn", None)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
device = kwargs["device"]
task_emb = torch.tensor([1, 0]).float().to(device)
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
out['y'] = comfy.conds.CONDRegular(task_emb)
return out
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
super().__init__(model_config, model_type, device=device)
class StableCascade_C(BaseModel): class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
@@ -1025,3 +1031,18 @@ class WAN21(BaseModel):
if clip_vision_output is not None: if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

View File

@@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0 dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
@@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "t2v" dit_config["model_type"] = "t2v"
return dit_config return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None
@@ -667,8 +682,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True

View File

@@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
def get_supported_float8_types():
float8_types = []
try:
float8_types.append(torch.float8_e4m3fn)
except:
pass
try:
float8_types.append(torch.float8_e4m3fnuz)
except:
pass
try:
float8_types.append(torch.float8_e5m2)
except:
pass
try:
float8_types.append(torch.float8_e5m2fnuz)
except:
pass
try:
float8_types.append(torch.float8_e8m0fnu)
except:
pass
return float8_types
FLOAT8_TYPES = get_supported_float8_types()
xpu_available = False xpu_available = False
torch_version = "" torch_version = ""
try: try:
@@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e5m2 return torch.float8_e5m2
fp8_dtype = None fp8_dtype = None
try: if weight_dtype in FLOAT8_TYPES:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: fp8_dtype = weight_dtype
fp8_dtype = weight_dtype
except:
pass
if fp8_dtype is not None: if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive

View File

@@ -17,7 +17,7 @@
""" """
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, TYPE_CHECKING from typing import Optional, Callable
import torch import torch
import copy import copy
import inspect import inspect
@@ -26,7 +26,6 @@ import uuid
import collections import collections
import math import math
import comfy.ops
import comfy.utils import comfy.utils
import comfy.float import comfy.float
import comfy.model_management import comfy.model_management
@@ -35,9 +34,6 @@ import comfy.hooks
import comfy.patcher_extension import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
if TYPE_CHECKING:
from comfy.model_base import BaseModel
def string_to_seed(data): def string_to_seed(data):
crc = 0xFFFFFFFF crc = 0xFFFFFFFF
@@ -205,7 +201,7 @@ class MemoryCounter:
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size self.size = size
self.model: BaseModel = model self.model = model
if not hasattr(self.model, 'device'): if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.") logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device self.model.device = offload_device
@@ -572,14 +568,6 @@ class ModelPatcher:
else: else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _zipper_dict_lowvram_only(self):
loading = self._load_list_lowvram_only()
def _load_list_lowvram_only(self):
loading = self._load_list()
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
def _load_list(self): def _load_list(self):
loading = [] loading = []
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
@@ -595,35 +583,6 @@ class ModelPatcher:
loading.append((comfy.model_management.module_size(m), n, m, params)) loading.append((comfy.model_management.module_size(m), n, m, params))
return loading return loading
def prepare_teeth(self):
ordered_list = self._load_list_lowvram_only()
prev_i = None
next_i = None
# first, create teeth on modules in list
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.init_tooth(self.load_device, self.offload_device, l[1])
# create teeth linked list
for i in range(len(ordered_list)):
if i+1 == len(ordered_list):
next_i = None
else:
next_i = i+1
m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
if prev_i is not None:
m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
else:
m.zipper_tooth.start = True
if next_i is not None:
m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
prev_i = i
def clean_teeth(self):
ordered_list = self._load_list_lowvram_only()
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.clean_tooth()
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected(): with self.use_ejected():
self.unpatch_hooks() self.unpatch_hooks()
@@ -632,8 +591,6 @@ class ModelPatcher:
lowvram_counter = 0 lowvram_counter = 0
loading = self._load_list() loading = self._load_list()
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
load_completely = [] load_completely = []
loading.sort(reverse=True) loading.sort(reverse=True)
for x in loading: for x in loading:
@@ -715,7 +672,6 @@ class ModelPatcher:
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.zipper_initialized = False
else: else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False self.model.model_lowvram = False
@@ -728,9 +684,6 @@ class ModelPatcher:
self.model.model_loaded_weight_memory = mem_counter self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid self.model.current_weight_patches_uuid = self.patches_uuid
if self.model.model_lowvram:
self.prepare_teeth()
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -762,7 +715,6 @@ class ModelPatcher:
move_weight_functions(m, device_to) move_weight_functions(m, device_to)
wipe_lowvram_weight(m) wipe_lowvram_weight(m)
self.clean_teeth()
self.model.model_lowvram = False self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0 self.model.lowvram_patch_counter = 0
@@ -852,10 +804,8 @@ class ModelPatcher:
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed self.model.model_loaded_weight_memory -= memory_freed
self.prepare_teeth()
return memory_freed return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

View File

@@ -69,6 +69,15 @@ class CONST:
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma) return latent / (1.0 - sigma)
class X0(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None): def __init__(self, model_config=None, zsnr=None):
super().__init__() super().__init__()

View File

@@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import torch import torch
import logging import logging
import comfy.model_management import comfy.model_management
@@ -57,79 +56,6 @@ class CastWeightBiasOp:
comfy_cast_weights = False comfy_cast_weights = False
weight_function = [] weight_function = []
bias_function = [] bias_function = []
zipper_init: dict = None
zipper_tooth: ZipperTooth = None
_zipper_tooth: ZipperTooth = None
def init_tooth(self, load_device, offload_device, key: str=None):
if self.zipper_tooth:
self.clean_tooth()
self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
def clean_tooth(self):
if self.zipper_tooth:
del self.zipper_tooth
self.zipper_tooth = None
def connect_teeth(self):
if self.zipper_init is not None:
self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
self.zipper_dict["prev_zipper_key"] = self.zipper_key
# def zipper_connect(self):
# if self.zipper_dict is not None:
# self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
# self.zipper_dict["prev_zipper_key"] = self.zipper_key
class ZipperTooth:
def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
self.op = op
self.key: str = key
self.weight_preloaded: torch.Tensor = None
self.bias_preloaded: torch.Tensor = None
self.load_device = load_device
self.offload_device = offload_device
self.start = False
self.prev_tooth: ZipperTooth = None
self.next_tooth: ZipperTooth = None
def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
try:
if self.start:
return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
return self.weight_preloaded, self.bias_preloaded
finally:
# if self.prev_tooth:
# self.prev_tooth.offload_previous(0)
self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
# TODO: queue load of tensors
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
if self.op.bias is not None:
self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
if self.next_tooth and teeth_count:
self.next_tooth.preload_next(teeth_count-1)
def offload_previous(self, teeth_count=1):
# TODO: queue offload of tensors
self.weight_preloaded = None
self.bias_preloaded = None
if self.prev_tooth and teeth_count:
self.prev_tooth.offload_previous(teeth_count-1)
class disable_weight_init: class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -137,11 +63,7 @@ class disable_weight_init:
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
#if self.zipper_init: weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -155,10 +77,7 @@ class disable_weight_init:
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -172,10 +91,7 @@ class disable_weight_init:
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -189,10 +105,7 @@ class disable_weight_init:
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -206,10 +119,7 @@ class disable_weight_init:
return None return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -224,10 +134,7 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
if self.weight is not None: if self.weight is not None:
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
else: else:
weight = None weight = None
bias = None bias = None
@@ -249,10 +156,7 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size, input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation) num_spatial_dims, self.dilation)
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d( return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding, input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
@@ -273,10 +177,7 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size, input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation) num_spatial_dims, self.dilation)
if self.zipper_tooth: weight, bias = cast_bias_weight(self, input)
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d( return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding, input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
@@ -296,10 +197,7 @@ class disable_weight_init:
output_dtype = out_dtype output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None out_dtype = None
if self.zipper_tooth: weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
else:
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):

View File

@@ -6,7 +6,6 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase from comfy.controlnet import ControlBase
from comfy.ops import CastWeightBiasOp
import torch import torch
from functools import partial from functools import partial
import collections import collections
@@ -19,7 +18,6 @@ import comfy.patcher_extension
import comfy.hooks import comfy.hooks
import scipy.stats import scipy.stats
import numpy import numpy
import comfy.ops
def add_area_dims(area, num_dims): def add_area_dims(area, num_dims):
@@ -362,38 +360,15 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns denoised #Returns denoised
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None uncond_ = None
else: else:
uncond_ = uncond uncond_ = uncond
do_cleanup = False
if "weight_zipper" not in model_options:
do_cleanup = True
#zipper_dict = {}
model_options["weight_zipper"] = True
loaded_modules = model.current_patcher._load_list_lowvram_only()
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
sum_m = sum([x[0] for x in low_m])
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(m, "comfy_cast_weights"):
m.zipper_tooth = comfy.ops.ZipperTooth
#m.zipper_dict = zipper_dict
m.zipper_key = l[1]
conds = [cond, uncond_] conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options) out = calc_cond_batch(model, conds, x, timestep, model_options)
if do_cleanup:
zzz = 20
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(l[2], "comfy_cast_weights"):
#m.zipper_dict = None
m.zipper_key = None
for fn in model_options.get("sampler_pre_cfg_function", []): for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options} "input": x, "sigma": timestep, "model": model, "model_options": model_options}

View File

@@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import yaml import yaml
import math import math
@@ -412,6 +413,17 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@@ -498,7 +510,7 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in): def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
pixel_samples = None pixel_samples = None
try: try:
@@ -510,7 +522,7 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None: if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out pixel_samples[x:x+batch_number] = out

View File

@@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
class LotusD(SD20):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"use_temporal_attention": False,
"adm_in_channels": 4,
"in_channels": 4,
}
unet_extra_config = {
"num_classes": 'sequential'
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lotus(self, device=device)
class SD3(supported_models_base.BASE): class SD3(supported_models_base.BASE):
unet_config = { unet_config = {
"in_channels": 16, "in_channels": 16,
@@ -959,6 +975,44 @@ class WAN21_I2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=True, device=device) out = model_base.WAN21(self, image_to_video=True, device=device)
return out return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
}
unet_extra_config = {}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
memory_usage_factor = 3.5
clip_vision_prefix = "conditioner.main_image_encoder.model."
vae_key_prefix = ["vae."]
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
"depth": 8,
}
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@@ -1,6 +1,9 @@
import nodes from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception): class DependencyCycleError(Exception):
pass pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self): def get_original_prompt(self):
return self.original_prompt return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None): def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES() valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None input_info = None
input_category = None input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes: if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name) _, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id) node_ids.append(from_node_id)

View File

@@ -0,0 +1,415 @@
import torch
import os
import json
import struct
import numpy as np
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/3d"
def generate(self, resolution, batch_size):
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision_output):
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class Hunyuan3Dv2ConditioningMultiView:
@classmethod
def INPUT_TYPES(s):
return {"required": {},
"optional": {"front": ("CLIP_VISION_OUTPUT",),
"left": ("CLIP_VISION_OUTPUT",),
"back": ("CLIP_VISION_OUTPUT",),
"right": ("CLIP_VISION_OUTPUT",), }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, front=None, left=None, back=None, right=None):
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
for i, e in enumerate(all_embeds):
if e is not None:
if pos_embeds is None:
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"vae": ("VAE", ),
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
}}
RETURN_TYPES = ("VOXEL",)
FUNCTION = "decode"
CATEGORY = "latent/3d"
def decode(self, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
binary = (voxels > threshold).float()
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
D, H, W = binary.shape
neighbors = torch.tensor([
[0, 0, 1],
[0, 0, -1],
[0, 1, 0],
[0, -1, 0],
[1, 0, 0],
[-1, 0, 0]
], device=device)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
solid_mask = binary.flatten() > 0
solid_indices = voxel_indices[solid_mask]
corner_offsets = [
torch.tensor([
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
], device=device),
torch.tensor([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
], device=device),
torch.tensor([
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
], device=device),
torch.tensor([
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
], device=device),
torch.tensor([
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
], device=device),
torch.tensor([
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
], device=device)
]
all_vertices = []
all_indices = []
vertex_count = 0
for face_idx, offset in enumerate(neighbors):
neighbor_indices = solid_indices + offset
padded_indices = neighbor_indices + 1
is_exposed = padded[
padded_indices[:, 0],
padded_indices[:, 1],
padded_indices[:, 2]
] == 0
if not is_exposed.any():
continue
exposed_indices = solid_indices[is_exposed]
corners = corner_offsets[face_idx].unsqueeze(0)
face_vertices = exposed_indices.unsqueeze(1) + corners
all_vertices.append(face_vertices.reshape(-1, 3))
num_faces = exposed_indices.shape[0]
face_indices = torch.arange(
vertex_count,
vertex_count + 4 * num_faces,
device=device
).reshape(-1, 4)
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
vertex_count += 4 * num_faces
if len(all_vertices) > 0:
vertices = torch.cat(all_vertices, dim=0)
faces = torch.cat(all_indices, dim=0)
else:
vertices = torch.zeros((1, 3))
faces = torch.zeros((1, 3))
v_min = 0
v_max = max(voxels.shape)
vertices = vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
vertices = vertices / scale
vertices = torch.fliplr(vertices)
return vertices, faces
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
class VoxelToMeshBasic:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, threshold):
vertices = []
faces = []
for x in voxel.data:
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_np = faces.cpu().numpy().astype(np.uint32)
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
buffer_data = vertices_buffer_padded + indices_buffer_padded
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [
{
"byteLength": len(buffer_data)
}
],
"bufferViews": [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
],
"accessors": [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
],
"meshes": [
{
"primitives": [
{
"attributes": {
"POSITION": 0
},
"indices": 1,
"mode": 4 # TRIANGLES
}
]
}
],
"nodes": [
{
"mesh": 0
}
],
"scenes": [
{
"nodes": [0]
}
],
"scene": 0
}
if metadata is not None:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header
# Magic glTF
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class SaveGLB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH", ),
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "3d"
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return {"ui": {"3d": results}}
NODE_CLASS_MAPPINGS = {
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"SaveGLB": SaveGLB,
}

View File

@@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@@ -32,12 +32,16 @@ class Load3D():
def process(self, model_file, image, **kwargs): def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image']) image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask']) mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage() load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path) output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, return output_image, output_mask, model_file, normal_image, lineart_image
class Load3DAnimation(): class Load3DAnimation():
@classmethod @classmethod
@@ -55,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path") RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@@ -66,12 +70,14 @@ class Load3DAnimation():
def process(self, model_file, image, **kwargs): def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image']) image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask']) mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
load_image_node = nodes.LoadImage() load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path) output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, return output_image, output_mask, model_file, normal_image
class Preview3D(): class Preview3D():
@classmethod @classmethod

File diff suppressed because one or more lines are too long

View File

@@ -20,10 +20,6 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input return c_out * x0 + c_skip * model_input
class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50 original_timesteps = 50
@@ -56,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0"],), "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"zsnr": ("BOOLEAN", {"default": False}), "zsnr": ("BOOLEAN", {"default": False}),
}} }}
@@ -77,7 +73,9 @@ class ModelSamplingDiscrete:
sampling_type = LCM sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0": elif sampling == "x0":
sampling_type = X0 sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass

View File

@@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict} return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["head."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1, "ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV, "ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B, "ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
} }

View File

@@ -2,6 +2,7 @@ import torch
import comfy.model_management import comfy.model_management
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
class Morphology: class Morphology:
@@ -40,8 +41,45 @@ class Morphology:
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,) return (img_out,)
class ImageRGBToYUV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"Morphology": Morphology, "Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

View File

@@ -0,0 +1,79 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
class String(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class Int(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"control_after_generate": True})},
}
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
class Boolean(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.26" __version__ = "0.3.27"

View File

@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {} missing_keys = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing(): def mark_missing():
missing_keys[x] = True missing_keys[x] = True
input_data_all[x] = (None,) input_data_all[x] = (None,)
@@ -555,7 +555,7 @@ def validate_inputs(prompt, item, validated):
received_types = {} received_types = {}
for x in valid_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None assert extra_info is not None
if x not in inputs: if x not in inputs:
if input_category == "required": if input_category == "required":
@@ -571,7 +571,7 @@ def validate_inputs(prompt, item, validated):
continue continue
val = inputs[x] val = inputs[x]
info = (type_input, extra_info) info = (input_type, extra_info)
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@@ -592,8 +592,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]] received_type = r[val[1]]
received_types[x] = received_type received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes", "message": "Return type mismatch between linked nodes",
@@ -641,22 +641,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"] val = val["__value__"]
inputs[x] = val inputs[x] = val
if type_input == "INT": if input_type == "INT":
val = int(val) val = int(val)
inputs[x] = val inputs[x] = val
if type_input == "FLOAT": if input_type == "FLOAT":
val = float(val) val = float(val)
inputs[x] = val inputs[x] = val
if type_input == "STRING": if input_type == "STRING":
val = str(val) val = str(val)
inputs[x] = val inputs[x] = val
if type_input == "BOOLEAN": if input_type == "BOOLEAN":
val = bool(val) val = bool(val)
inputs[x] = val inputs[x] = val
except Exception as ex: except Exception as ex:
error = { error = {
"type": "invalid_input_type", "type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value", "message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}", "details": f"{x}, {val}, {ex}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@@ -696,18 +696,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if isinstance(type_input, list): if isinstance(input_type, list):
if val not in type_input: combo_options = input_type
if val not in combo_options:
input_config = info input_config = info
list_info = "" list_info = ""
# Don't send back gigantic lists like if they're lots of # Don't send back gigantic lists like if they're lots of
# scanned model filepaths # scanned model filepaths
if len(type_input) > 20: if len(combo_options) > 20:
list_info = f"(list of length {len(type_input)})" list_info = f"(list of length {len(combo_options)})"
input_config = None input_config = None
else: else:
list_info = str(type_input) list_info = str(combo_options)
error = { error = {
"type": "value_not_in_list", "type": "value_not_in_list",

View File

@@ -2264,6 +2264,9 @@ def init_builtin_extra_nodes():
"nodes_video.py", "nodes_video.py",
"nodes_lumina2.py", "nodes_lumina2.py",
"nodes_wan.py", "nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
] ]
import_failed = [] import_failed = []

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.26" version = "0.3.27"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.12.14 comfyui-frontend-package==1.14.5
torch torch
torchsde torchsde
torchvision torchvision