Compare commits

...

24 Commits

Author SHA1 Message Date
Yoland Y
d58ad2dd19 Expand supported image file extensions in LoadImageSetNode 2025-03-26 17:30:42 +08:00
Yoland Y
b87f55ed65 Move allow batch execution logic to different PR 2025-03-26 17:30:42 +08:00
Yoland Yan
0edc48af99 Remove empty spaces 2025-03-26 17:30:42 +08:00
Yoland Yan
bfc2f177e8 Refactor import statements in nodes_train.py
Reorganized and cleaned up import statements, removing unused imports and adding specific module imports for better clarity and organization.
2025-03-26 17:30:42 +08:00
Yoland Yan
f03ece18f2 Add remaining patch 2025-03-26 17:30:42 +08:00
Yoland Yan
2cd3c8a2fb Fix ruff errors 2025-03-26 17:30:41 +08:00
Yoland Yan
225a196dae Feat: Add basic LoRA training support
For more details: https://github.com/Comfy-Org/rfcs/pull/26
2025-03-26 17:30:41 +08:00
comfyanonymous
84fdaf7b0e Add CFGZeroStar node.
Works on all models that use a negative prompt but is meant for rectified
flow models.
2025-03-26 05:09:52 -04:00
comfyanonymous
8edc1f44c1 Support more float8 types. 2025-03-25 05:23:49 -04:00
comfyanonymous
eade1551bb Add Hunyuan3D to readme. 2025-03-24 07:14:32 -04:00
comfyanonymous
581a9991ff Add model merging node for WAN 2.1 2025-03-23 08:06:36 -04:00
comfyanonymous
e471c726e5 Fallback to pytorch attention if sage attention fails. 2025-03-22 15:45:56 -04:00
comfyanonymous
75c1c757d9 ComfyUI version v0.3.27 2025-03-21 20:09:54 -04:00
Chenlei Hu
ce9b084279 [nit] Format error strings (#7345) 2025-03-21 19:08:25 -04:00
Terry Jia
2206246055 support output normal and lineart once (#7290) 2025-03-21 16:24:13 -04:00
comfyanonymous
d9fa9d307f Automatically set the right sampling type for lotus. 2025-03-21 14:19:37 -04:00
thot experiment
83e839a89b Native LotusD Implementation (#7125)
* draft pass at a native comfy implementation of Lotus-D depth and normal est

* fix model_sampling kludges

* fix ruff

---------

Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
2025-03-21 14:04:15 -04:00
Chenlei Hu
0cf2274699 Update frontend to 1.14 (#7343) 2025-03-21 13:50:09 -04:00
comfyanonymous
0956107170 Nodes to convert images to YUV and back.
Can be used to convert an image to black and white.
2025-03-21 06:32:44 -04:00
Chenlei Hu
a4a956dbbd Add backend primitive nodes (#7328)
* Add backend primitive nodes

* Add control after generate to int primitive
2025-03-21 01:47:18 -04:00
Chenlei Hu
8b9ce4ed18 Update frontend to 1.13 (#7331) 2025-03-21 00:17:36 -04:00
comfyanonymous
3872b43d4b A few fixes for the hunyuan3d models. 2025-03-20 04:52:31 -04:00
comfyanonymous
32ca0805b7 Fix orientation of hunyuan 3d model. 2025-03-19 19:55:24 -04:00
comfyanonymous
11f1b41bab Initial Hunyuan3Dv2 implementation.
Supports the multiview, mini, turbo models and VAEs.
2025-03-19 16:52:58 -04:00
31 changed files with 2403 additions and 63 deletions

View File

@@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.

View File

@@ -22,13 +22,21 @@ import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
def check_frontend_version():
@@ -43,7 +51,17 @@ def check_frontend_version():
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
@@ -150,9 +168,20 @@ class FrontendManager:
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
logging.error(
f"""
********** ERROR ***********
comfyui-frontend-package is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
sys.exit(-1)
@classmethod
@@ -175,7 +204,9 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
def init_frontend_unsafe(
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
"""
Initializes the frontend for the specified version.
@@ -197,12 +228,20 @@ class FrontendManager:
repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
expected_path = str(
Path(cls.CUSTOM_FRONTENDS_ROOT)
/ f"{repo_owner}_{repo_name}"
/ version.lstrip("v")
)
if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
logging.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
logging.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)

View File

@@ -37,6 +37,8 @@ class IO(StrEnum):
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"

View File

@@ -456,3 +456,13 @@ class Wan21(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404

View File

@@ -797,12 +797,15 @@ class GeneralDITTransformerBlock(nn.Module):
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for block in self.blocks:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
if self.training:
x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
else:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x

View File

@@ -0,0 +1,135 @@
import torch
from torch import nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
class Hunyuan3Dv2(nn.Module):
def __init__(
self,
in_channels=64,
context_in_dim=1536,
hidden_size=1024,
mlp_ratio=4.0,
num_heads=16,
depth=16,
depth_single_blocks=32,
qkv_bias=True,
guidance_embed=False,
image_model=None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dtype = dtype
if hidden_size % num_heads != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
)
self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
)
self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
dtype=dtype, device=device, operations=operations
)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
img = self.latent_in(x)
vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
if self.guidance_in is not None:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
txt = self.cond_in(txt)
pe = None
attn_mask = None
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
return img.movedim(-2, -1) * (-1.0)

587
comfy/ldm/hunyuan3d/vae.py Normal file
View File

@@ -0,0 +1,587 @@
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
[
sin(x[..., i]),
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i] # only present if include_input is True.
], here f_i is the frequency.
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
Args:
num_freqs (int): the number of frequencies, default is 6;
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
input_dim (int): the input dimension, default is 3;
include_input (bool): include the input tensor or not, default is True.
Attributes:
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
otherwise, it is input_dim * num_freqs * 2.
"""
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
"""The initialization"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = F.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
class MLP(nn.Module):
def __init__(
self, *,
width: int,
expand_ratio: int = 4,
output_width: int = None,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.c_fc = ops.Linear(width, width * expand_ratio)
self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
data_width: Optional[int] = None,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
kv_cache: bool = False,
):
super().__init__()
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = ops.Linear(width, width, bias=qkv_bias)
self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.kv_cache = kv_cache
self.data = None
def forward(self, x, data):
x = self.c_q(x)
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
logging.info('Save kv cache,this should be called only once for one mesh')
data = self.data
else:
data = self.c_kv(data)
x = self.attention(x, data)
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class QKVMultiheadAttention(nn.Module):
def __init__(
self,
*,
heads: int,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
x = self.c_qkv(x)
x = self.attention(x)
x = self.drop_path(self.c_proj(x))
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
):
super().__init__()
self.attn = MultiheadAttention(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
norm_layer=ops.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class CrossAttentionDecoder(nn.Module):
def __init__(
self,
*,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
downsample_ratio: int = 1,
enable_ln_post: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
):
super().__init__()
self.enable_ln_post = enable_ln_post
self.fourier_embedder = fourier_embedder
self.downsample_ratio = downsample_ratio
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if self.enable_ln_post == False:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
mlp_expand_ratio=mlp_expand_ratio,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if self.enable_ln_post:
self.ln_post = ops.LayerNorm(width)
self.output_proj = ops.Linear(width, out_channels)
self.label_type = label_type
self.count = 0
def forward(self, queries=None, query_embeddings=None, latents=None):
if query_embeddings is None:
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
self.count += query_embeddings.shape[1]
if self.downsample_ratio != 1:
latents = self.latents_proj(latents)
x = self.cross_attn_decoder(query_embeddings, latents)
if self.enable_ln_post:
x = self.ln_post(x)
occ = self.output_proj(x)
return occ
class ShapeVAE(nn.Module):
def __init__(
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
width=width,
layers=num_decoder_layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.geo_decoder = CrossAttentionDecoder(
fourier_embedder=self.fourier_embedder,
out_channels=1,
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
downsample_ratio=geo_decoder_downsample_ratio,
enable_ln_post=self.geo_decoder_ln_post,
width=width // geo_decoder_downsample_ratio,
heads=heads // geo_decoder_downsample_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
)
self.volume_decoder = VanillaVolumeDecoder()
self.scale_factor = scale_factor
def decode(self, latents, **kwargs):
latents = self.post_kl(latents.movedim(-2, -1))
latents = self.transformer(latents)
bounds = kwargs.get("bounds", 1.01)
num_chunks = kwargs.get("num_chunks", 8000)
octree_resolution = kwargs.get("octree_resolution", 256)
enable_pbar = kwargs.get("enable_pbar", True)
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, x):
return None

View File

@@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
tensor_layout = "HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
@@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
tensor_layout = "NHD"
if mask is not None:
# add a batch dimension if there isn't already one
@@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
@@ -740,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
@@ -780,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n = p(n, extra_options)
x += n
x = n + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = x_skip + x
return x

View File

@@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.model_management
import comfy.patcher_extension
@@ -58,6 +59,7 @@ class ModelType(Enum):
FLOW = 6
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -88,6 +90,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG:
c = comfy.model_sampling.IMG_TO_IMG
class ModelSampling(s, c):
pass
@@ -139,6 +143,7 @@ class BaseModel(torch.nn.Module):
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
@@ -600,6 +605,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
class Lotus(BaseModel):
def extra_conds(self, **kwargs):
out = {}
cross_attn = kwargs.get("cross_attn", None)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
device = kwargs["device"]
task_emb = torch.tensor([1, 0]).float().to(device)
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
out['y'] = comfy.conds.CONDRegular(task_emb)
return out
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
super().__init__(model_config, model_type, device=device)
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
@@ -1013,3 +1031,18 @@ class WAN21(BaseModel):
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

View File

@@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "t2v"
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -667,8 +682,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
total_vram = 0
def get_supported_float8_types():
float8_types = []
try:
float8_types.append(torch.float8_e4m3fn)
except:
pass
try:
float8_types.append(torch.float8_e4m3fnuz)
except:
pass
try:
float8_types.append(torch.float8_e5m2)
except:
pass
try:
float8_types.append(torch.float8_e5m2fnuz)
except:
pass
try:
float8_types.append(torch.float8_e8m0fnu)
except:
pass
return float8_types
FLOAT8_TYPES = get_supported_float8_types()
xpu_available = False
torch_version = ""
try:
@@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e5m2
fp8_dtype = None
try:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
except:
pass
if weight_dtype in FLOAT8_TYPES:
fp8_dtype = weight_dtype
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive

View File

@@ -17,23 +17,26 @@
"""
from __future__ import annotations
from typing import Optional, Callable
import torch
import collections
import copy
import inspect
import logging
import uuid
import collections
import math
import uuid
from typing import Callable, Optional
import torch
import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF

View File

@@ -69,6 +69,15 @@ class CONST:
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
class X0(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
super().__init__()

View File

@@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import yaml
import math
@@ -412,6 +413,17 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -498,7 +510,7 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
try:
@@ -510,7 +522,7 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
@@ -974,7 +986,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
Args:
sd (dict): State dictionary containing model weights and configuration
model_options (dict, optional): Additional options for model loading. Supports:
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations
Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
Returns None if the model configuration cannot be detected.
The function:
1. Detects and handles different model formats (regular, diffusers, mmdit)
2. Configures model dtype based on parameters and device capabilities
3. Handles weight conversion and device placement
4. Manages model optimization settings
5. Loads weights and returns a device-managed model instance
"""
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files

View File

@@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
class LotusD(SD20):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"use_temporal_attention": False,
"adm_in_channels": 4,
"in_channels": 4,
}
unet_extra_config = {
"num_classes": 'sequential'
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lotus(self, device=device)
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
@@ -959,6 +975,44 @@ class WAN21_I2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
}
unet_extra_config = {}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
memory_usage_factor = 3.5
clip_vision_prefix = "conditioner.main_image_encoder.model."
vae_key_prefix = ["vae."]
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
"depth": 8,
}
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]

45
comfy_extras/nodes_cfg.py Normal file
View File

@@ -0,0 +1,45 @@
import torch
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
negative_flat = negative.reshape(negative.shape[0], -1)
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def patch(self, model):
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
x = args['input']
cond_p = args['cond_denoised']
uncond_p = args['uncond_denoised']
out = args["denoised"]
alpha = optimized_scale(x - cond_p, x - uncond_p)
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
}

View File

@@ -0,0 +1,415 @@
import torch
import os
import json
import struct
import numpy as np
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/3d"
def generate(self, resolution, batch_size):
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision_output):
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class Hunyuan3Dv2ConditioningMultiView:
@classmethod
def INPUT_TYPES(s):
return {"required": {},
"optional": {"front": ("CLIP_VISION_OUTPUT",),
"left": ("CLIP_VISION_OUTPUT",),
"back": ("CLIP_VISION_OUTPUT",),
"right": ("CLIP_VISION_OUTPUT",), }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, front=None, left=None, back=None, right=None):
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
for i, e in enumerate(all_embeds):
if e is not None:
if pos_embeds is None:
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"vae": ("VAE", ),
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
}}
RETURN_TYPES = ("VOXEL",)
FUNCTION = "decode"
CATEGORY = "latent/3d"
def decode(self, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
binary = (voxels > threshold).float()
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
D, H, W = binary.shape
neighbors = torch.tensor([
[0, 0, 1],
[0, 0, -1],
[0, 1, 0],
[0, -1, 0],
[1, 0, 0],
[-1, 0, 0]
], device=device)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
solid_mask = binary.flatten() > 0
solid_indices = voxel_indices[solid_mask]
corner_offsets = [
torch.tensor([
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
], device=device),
torch.tensor([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
], device=device),
torch.tensor([
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
], device=device),
torch.tensor([
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
], device=device),
torch.tensor([
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
], device=device),
torch.tensor([
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
], device=device)
]
all_vertices = []
all_indices = []
vertex_count = 0
for face_idx, offset in enumerate(neighbors):
neighbor_indices = solid_indices + offset
padded_indices = neighbor_indices + 1
is_exposed = padded[
padded_indices[:, 0],
padded_indices[:, 1],
padded_indices[:, 2]
] == 0
if not is_exposed.any():
continue
exposed_indices = solid_indices[is_exposed]
corners = corner_offsets[face_idx].unsqueeze(0)
face_vertices = exposed_indices.unsqueeze(1) + corners
all_vertices.append(face_vertices.reshape(-1, 3))
num_faces = exposed_indices.shape[0]
face_indices = torch.arange(
vertex_count,
vertex_count + 4 * num_faces,
device=device
).reshape(-1, 4)
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
vertex_count += 4 * num_faces
if len(all_vertices) > 0:
vertices = torch.cat(all_vertices, dim=0)
faces = torch.cat(all_indices, dim=0)
else:
vertices = torch.zeros((1, 3))
faces = torch.zeros((1, 3))
v_min = 0
v_max = max(voxels.shape)
vertices = vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
vertices = vertices / scale
vertices = torch.fliplr(vertices)
return vertices, faces
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
class VoxelToMeshBasic:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, threshold):
vertices = []
faces = []
for x in voxel.data:
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_np = faces.cpu().numpy().astype(np.uint32)
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
buffer_data = vertices_buffer_padded + indices_buffer_padded
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [
{
"byteLength": len(buffer_data)
}
],
"bufferViews": [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
],
"accessors": [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
],
"meshes": [
{
"primitives": [
{
"attributes": {
"POSITION": 0
},
"indices": 1,
"mode": 4 # TRIANGLES
}
]
}
],
"nodes": [
{
"mesh": 0
}
],
"scenes": [
{
"nodes": [0]
}
],
"scene": 0
}
if metadata is not None:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header
# Magic glTF
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class SaveGLB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH", ),
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "3d"
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return {"ui": {"3d": results}}
NODE_CLASS_MAPPINGS = {
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"SaveGLB": SaveGLB,
}

View File

@@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -32,12 +32,16 @@ class Load3D():
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file, normal_image, lineart_image
class Load3DAnimation():
@classmethod
@@ -55,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("image", "mask", "mesh_path")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -66,12 +70,14 @@ class Load3DAnimation():
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file, normal_image
class Preview3D():
@classmethod

File diff suppressed because one or more lines are too long

View File

@@ -20,10 +20,6 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
@@ -56,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@@ -77,7 +73,9 @@ class ModelSamplingDiscrete:
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = X0
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

View File

@@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["head."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
}

View File

@@ -2,6 +2,7 @@ import torch
import comfy.model_management
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
class Morphology:
@@ -40,8 +41,45 @@ class Morphology:
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)
class ImageRGBToYUV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB:
@classmethod
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@@ -0,0 +1,79 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
class String(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class Int(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"control_after_generate": True})},
}
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
class Boolean(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}

646
comfy_extras/nodes_train.py Normal file
View File

@@ -0,0 +1,646 @@
import datetime
import json
import logging
import math
import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import comfy.samplers
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(sigmas),
torch.zeros_like(noise, requires_grad=True),
latent_image,
False
)
# Ensure model is in training mode and computing gradients
denoised = model_wrap(noise, sigmas, **extra_args)
try:
loss = self.loss_fn(denoised, latent.clone())
except RuntimeError as e:
if "does not require grad and does not have a grad_fn" in str(e):
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
loss.backward()
logging.info(f"Current Training Loss: {loss.item():.6f}")
if self.loss_callback:
self.loss_callback(loss.item())
self.optimizer.step()
# torch.cuda.memory._dump_snapshot("trainn.pickle")
# torch.cuda.memory._record_memory_history(enabled=None)
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
return b + self.bias
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
class LoraDiff(torch.nn.Module):
def __init__(self, lora_down, lora_up):
super().__init__()
self.lora_down = lora_down
self.lora_up = lora_up
def __call__(self, w):
return w + (self.lora_up @ self.lora_down).reshape(w.shape)
def passive_memory_usage(self):
return self.lora_down.nelement() * self.lora_down.element_size() + self.lora_up.nelement() * self.lora_up.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def load_and_process_images(image_files, input_dir, resize_method="None"):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
w, h = None, None
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
class LoadImageSetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": (
[
f
for f in os.listdir(folder_paths.get_input_directory())
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
],
{"image_upload": True, "allow_batch": True},
)
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
INPUT_IS_LIST = True
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
@classmethod
def VALIDATE_INPUTS(s, images, resize_method):
filenames = images[0] if isinstance(images[0], list) else images
for image in filenames:
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
def load_images(self, input_files, resize_method):
input_dir = folder_paths.get_input_directory()
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
image_files = [
f
for f in input_files
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
return (output_tensor,)
class LoadImageSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
def load_images(self, folder, resize_method):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
return (output_tensor,)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
class TrainLoraNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
"vae": (
IO.VAE,
{
"tooltip": "The VAE model to use for encoding images for training."
},
),
"positive": (
IO.CONDITIONING,
{"tooltip": "The positive conditioning to use for training."},
),
"image": (
IO.IMAGE,
{"tooltip": "The image or image batch to train the LoRA on."},
),
"batch_size": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 10000,
"step": 1,
"tooltip": "The batch size to use for training.",
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 1000,
"tooltip": "The number of steps to train the LoRA for.",
},
),
"learning_rate": (
IO.FLOAT,
{
"default": 0.0003,
"min": 0.0000001,
"max": 1.0,
"step": 0.00001,
"tooltip": "The learning rate to use for training.",
},
),
"rank": (
IO.INT,
{
"default": 8,
"min": 1,
"max": 128,
"tooltip": "The rank of the LoRA layers.",
},
),
"optimizer": (
["Adam", "AdamW", "SGD", "RMSprop"],
{
"default": "Adam",
"tooltip": "The optimizer to use for training.",
},
),
"loss_function": (
["MSE", "L1", "Huber", "SmoothL1"],
{
"default": "MSE",
"tooltip": "The loss function to use for training.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
},
),
"training_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for training."},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
"default": "[None]",
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
},
),
},
}
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
FUNCTION = "train"
CATEGORY = "training"
EXPERIMENTAL = True
def train(
self,
model,
vae,
positive,
image,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
existing_lora,
):
num_images = image.shape[0]
indices = torch.randperm(num_images)[:batch_size]
batch_tensor = image[indices]
# Ensure we're not in inference mode when encoding
encoded = vae.encode(batch_tensor)
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
in_dim = math.prod(shape[1:])
out_dim = shape[0]
# Check if we have existing weights for this layer
lora_up_key = "{}.lora_up.weight".format(n)
lora_down_key = "{}.lora_down.weight".format(n)
if existing_lora != "[None]" and (
lora_up_key in existing_weights
and lora_down_key in existing_weights
):
# Initialize with existing weights
lora_up = torch.nn.Parameter(
existing_weights[lora_up_key].to(dtype=dtype),
requires_grad=True,
)
lora_down = torch.nn.Parameter(
existing_weights[lora_down_key].to(dtype=dtype),
requires_grad=True,
)
else:
if existing_lora != "[None]":
logging.info(f"Warning: No existing weights found for {lora_up_key} or {lora_down_key}")
# Initialize new weights
lora_down = torch.nn.Parameter(
torch.zeros(
(
rank,
in_dim,
),
dtype=dtype,
),
requires_grad=True,
)
lora_up = torch.nn.Parameter(
torch.zeros((out_dim, rank), dtype=dtype),
requires_grad=True,
)
torch.nn.init.zeros_(lora_up)
torch.nn.init.kaiming_uniform_(
lora_down, a=math.sqrt(5), generator=generator
)
lora_sd[lora_up_key] = lora_up
lora_sd[lora_down_key] = lora_down
mp.add_weight_wrapper(key, LoraDiff(lora_down, lora_up))
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=dtype, requires_grad=True
)
)
mp.add_weight_wrapper(key, BiasDiff(diff))
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=dtype, requires_grad=True)
)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# Setup sampler and guider like in test script
loss_map = {"loss": []}
loss_callback = lambda loss: loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
# yoland: this currently resize to the first image in the dataset
# Training loop
for step in range(steps):
# Generate random sigma
sigma = mp.model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
sigma = torch.tensor([sigma])
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
ss.sample(
noise, guider, train_sampler, sigma, {"samples": encoded.clone()}
)
return (mp, lora_sd, loss_map, steps + existing_steps)
class SaveLoRA:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (
IO.LORA_MODEL,
{
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
},
),
"prefix": (
"STRING",
{
"default": "trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
},
"optional": {
"steps": (
IO.INT,
{
"forceInput": True,
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
},
),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "loaders"
EXPERIMENTAL = True
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if steps is None:
output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
else:
output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
safetensors.torch.save_file(lora, output_file)
return {}
class LossGraphNode:
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"loss": (IO.LOSS_MAP, {"default": {}}),
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "plot_loss"
OUTPUT_NODE = True
CATEGORY = "training"
EXPERIMENTAL = True
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 500, 300
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img.save(
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
pnginfo=metadata,
)
return {
"ui": {
"images": [
{
"filename": f"{filename_prefix}_{date}.png",
"subfolder": "",
"type": "temp",
}
]
}
}
NODE_CLASS_MAPPINGS = {
"TrainLoraNode": TrainLoraNode,
"SaveLoRANode": SaveLoRA,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TrainLoraNode": "Train LoRA",
"SaveLoRANode": "Save LoRA Weights",
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
"LossGraphNode": "Plot Loss Graph",
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.26"
__version__ = "0.3.27"

View File

@@ -1,23 +1,34 @@
import sys
import copy
import logging
import threading
import heapq
import inspect
import logging
import sys
import threading
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
import nodes
from comfy_execution.caching import (
CacheKeySetID,
CacheKeySetInputSignature,
HierarchicalCache,
LRUCache,
)
from comfy_execution.graph import (
DynamicPrompt,
ExecutionBlocker,
ExecutionList,
get_input_info,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1

View File

@@ -272,6 +272,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
"""
Get the full path of a file in a folder, has to be a file
"""
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -289,6 +292,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
"""
Get the full path of a file in a folder, has to be a file
"""
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -390,3 +396,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
def get_input_subfolders() -> list[str]:
"""Returns a list of all subfolder paths in the input directory, recursively.
Returns:
List of folder paths relative to the input directory, excluding the root directory
"""
input_dir = get_input_directory()
folders = []
try:
if not os.path.exists(input_dir):
return []
for root, dirs, _ in os.walk(input_dir):
rel_path = os.path.relpath(root, input_dir)
if rel_path != ".": # Only include non-root directories
# Normalize path separators to forward slashes
folders.append(rel_path.replace(os.sep, '/'))
return sorted(folders)
except FileNotFoundError:
return []

View File

@@ -2229,6 +2229,7 @@ def init_builtin_extra_nodes():
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
"nodes_train.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
@@ -2264,6 +2265,10 @@ def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.26"
version = "0.3.27"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.12.14
comfyui-frontend-package==1.14.5
torch
torchsde
torchvision

View File

@@ -0,0 +1,51 @@
import pytest
import os
import tempfile
from folder_paths import get_input_subfolders, set_input_directory
@pytest.fixture(scope="module")
def mock_folder_structure():
with tempfile.TemporaryDirectory() as temp_dir:
# Create a nested folder structure
folders = [
"folder1",
"folder1/subfolder1",
"folder1/subfolder2",
"folder2",
"folder2/deep",
"folder2/deep/nested",
"empty_folder"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(temp_dir, folder))
# Add some files to test they're not included
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
f.write("test")
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
f.write("test")
set_input_directory(temp_dir)
yield temp_dir
def test_gets_all_folders(mock_folder_structure):
folders = get_input_subfolders()
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
assert sorted(folders) == sorted(expected)
def test_handles_nonexistent_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
nonexistent = os.path.join(temp_dir, "nonexistent")
set_input_directory(nonexistent)
assert get_input_subfolders() == []
def test_empty_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
set_input_directory(temp_dir)
assert get_input_subfolders() == [] # Empty since we don't include root