Compare commits

...

7 Commits

Author SHA1 Message Date
Yoland Yan
b545667469 Update test-build.yml 2025-05-06 02:42:31 -04:00
comfyanonymous
d9c80a85e5 This should not be a warning. (#7946) 2025-05-05 07:49:07 -04:00
Christian Byrne
3e62c5513a make audio chunks contiguous before encoding (#7942) 2025-05-04 23:27:23 -04:00
Christian Byrne
cd18582578 Support saving Comfy VIDEO type to buffer (#7939)
* get output format when saving to buffer

* add unit tests for writing to file or stream with correct fmt

* handle `to_format=None`

* fix formatting
2025-05-04 23:26:57 -04:00
comfyanonymous
80a44b97f5 Change lumina to native RMSNorm. (#7935) 2025-05-04 06:39:23 -04:00
comfyanonymous
9187a09483 Change cosmos and hydit models to use the native RMSNorm. (#7934) 2025-05-04 06:26:20 -04:00
comfyanonymous
3041e5c354 Switch mochi and wan modes to use pytorch RMSNorm. (#7925)
* Switch genmo model to native RMSNorm.

* Switch WAN to native RMSNorm.
2025-05-03 19:07:55 -04:00
11 changed files with 168 additions and 46 deletions

View File

@@ -18,7 +18,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python-version: ["3.10", "3.11", "3.12", "3.13"]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@@ -127,8 +127,8 @@ class CustomNodeManager:
if os.path.exists(workflows_dir): if os.path.exists(workflows_dir):
if folder_name != "example_workflows": if folder_name != "example_workflows":
logging.warning( logging.debug(
"WARNING: Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'", "Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
folder_name, module_name) folder_name, module_name)
webapp.add_routes( webapp.add_routes(

View File

@@ -23,7 +23,6 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
from torch import nn from torch import nn
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
return t_out return t_out
def get_normalization(name: str, channels: int, weight_args={}): def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I": if name == "I":
return nn.Identity() return nn.Identity()
elif name == "R": elif name == "R":
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args) return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else: else:
raise ValueError(f"Normalization {name} not found") raise ValueError(f"Normalization {name} not found")
@@ -120,15 +119,15 @@ class Attention(nn.Module):
self.to_q = nn.Sequential( self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args), operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim), get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
) )
self.to_k = nn.Sequential( self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args), operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim), get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
) )
self.to_v = nn.Sequential( self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args), operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim), get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
) )
self.to_out = nn.Sequential( self.to_out = nn.Sequential(

View File

@@ -27,8 +27,6 @@ from torchvision import transforms
from enum import Enum from enum import Enum
import logging import logging
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from .blocks import ( from .blocks import (
FinalLayer, FinalLayer,
GeneralDITTransformerBlock, GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
if self.affline_emb_norm: if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer") logging.debug("Building affine embedding normalization layer")
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6) self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
else: else:
self.affline_norm = nn.Identity() self.affline_norm = nn.Identity()

View File

@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
from .layers import ( from .layers import (
FeedForward, FeedForward,
PatchEmbed, PatchEmbed,
RMSNorm,
TimestepEmbedder, TimestepEmbedder,
) )
@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability. # Query and key normalization for stability.
assert qk_norm assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y. # Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype) self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)

View File

@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
x = self.norm(x) x = self.norm(x)
return x return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

View File

@@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import comfy.ops import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint from torch.utils import checkpoint
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
if norm_type == "layer": if norm_type == "layer":
norm_layer = operations.LayerNorm norm_layer = operations.LayerNorm
elif norm_type == "rms": elif norm_type == "rms":
norm_layer = RMSNorm norm_layer = operations.RMSNorm
else: else:
raise ValueError(f"Unknown norm_type: {norm_type}") raise ValueError(f"Unknown norm_type: {norm_type}")

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import comfy.ldm.common_dit import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
) )
if qk_norm: if qk_norm:
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
else: else:
self.q_norm = self.k_norm = nn.Identity() self.q_norm = self.k_norm = nn.Identity()
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
operation_settings=operation_settings, operation_settings=operation_settings,
) )
self.layer_id = layer_id self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.modulation = modulation self.modulation = modulation
if modulation: if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.cap_embedder = nn.Sequential( self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings), operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
cap_feat_dim, cap_feat_dim,
dim, dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers) for layer_id in range(n_layers)
] ]
) )
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
assert (dim // n_heads) == sum(axes_dims) assert (dim // n_heads) == sum(axes_dims)

View File

@@ -9,7 +9,6 @@ from einops import repeat
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_management import comfy.model_management
@@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs): def forward(self, x, freqs):
r""" r"""
@@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# self.alpha = nn.Parameter(torch.zeros((1, ))) # self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len): def forward(self, x, context, context_img_len):
r""" r"""

View File

@@ -12,6 +12,46 @@ import torch
from comfy_api.input import VideoInput from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput): class VideoFromFile(VideoInput):
""" """
Class representing video input from a file. Class representing video input from a file.
@@ -89,7 +129,7 @@ class VideoFromFile(VideoInput):
def save_to( def save_to(
self, self,
path: str, path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None
@@ -116,7 +156,9 @@ class VideoFromFile(VideoInput):
) )
streams = container.streams streams = container.streams
with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata # Copy over the original metadata
for key, value in container.metadata.items(): for key, value in container.metadata.items():
if metadata is None or key not in metadata: if metadata is None or key not in metadata:
@@ -211,7 +253,12 @@ class VideoFromComponents(VideoInput):
start = i * samples_per_frame start = i * samples_per_frame
end = start + samples_per_frame end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio # TODO(Feature) - Add support for stereo audio
chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy() chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame audio_frame.pts = i * samples_per_frame

View File

@@ -0,0 +1,91 @@
import io
from comfy_api.input_impl.video_types import (
container_to_output_format,
get_open_write_kwargs,
)
from comfy_api.util import VideoContainer
def test_container_to_output_format_empty_string():
"""Test that an empty string input returns None. `None` arg allows default auto-detection."""
assert container_to_output_format("") is None
def test_container_to_output_format_none():
"""Test that None input returns None."""
assert container_to_output_format(None) is None
def test_container_to_output_format_comma_separated():
"""Test that a comma-separated list returns a valid singular format from the list."""
comma_separated_format = "mp4,mov,m4a"
output_format = container_to_output_format(comma_separated_format)
assert output_format in comma_separated_format
def test_container_to_output_format_single():
"""Test that a single format string (not comma-separated list) is returned as is."""
assert container_to_output_format("mp4") == "mp4"
def test_get_open_write_kwargs_filepath_no_format():
"""Test that 'format' kwarg is NOT set when dest is a file path."""
kwargs_auto = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
assert "format" not in kwargs_auto, "Format should not be set for file paths (AUTO)"
kwargs_specific = get_open_write_kwargs("output.avi", "mp4", "avi")
fail_msg = "Format should not be set for file paths (Specific)"
assert "format" not in kwargs_specific, fail_msg
def test_get_open_write_kwargs_base_options_mode():
"""Test basic kwargs for file path: mode and movflags."""
kwargs = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
assert kwargs["mode"] == "w", "mode should be set to write"
fail_msg = "movflags should be set to preserve custom metadata tags"
assert "movflags" in kwargs["options"], fail_msg
assert kwargs["options"]["movflags"] == "use_metadata_tags", fail_msg
def test_get_open_write_kwargs_bytesio_auto_format():
"""Test kwargs for BytesIO dest with AUTO format."""
dest = io.BytesIO()
container_fmt = "mov,mp4,m4a"
kwargs = get_open_write_kwargs(dest, container_fmt, VideoContainer.AUTO)
assert kwargs["mode"] == "w"
assert kwargs["options"]["movflags"] == "use_metadata_tags"
fail_msg = (
"Format should be a valid format from the container's format list when AUTO"
)
assert kwargs["format"] in container_fmt, fail_msg
def test_get_open_write_kwargs_bytesio_specific_format():
"""Test kwargs for BytesIO dest with a specific single format."""
dest = io.BytesIO()
container_fmt = "avi"
to_fmt = VideoContainer.MP4
kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)
assert kwargs["mode"] == "w"
assert kwargs["options"]["movflags"] == "use_metadata_tags"
fail_msg = "Format should be the specified format (lowercased) when output format is not AUTO"
assert kwargs["format"] == "mp4", fail_msg
def test_get_open_write_kwargs_bytesio_specific_format_list():
"""Test kwargs for BytesIO dest with a specific comma-separated format."""
dest = io.BytesIO()
container_fmt = "avi"
to_fmt = "mov,mp4,m4a" # A format string that is a list
kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)
assert kwargs["mode"] == "w"
assert kwargs["options"]["movflags"] == "use_metadata_tags"
fail_msg = "Format should be a valid format from the specified format list when output format is not AUTO"
assert kwargs["format"] in to_fmt, fail_msg