Compare commits
7 Commits
v0.3.31
...
yoland68-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b545667469 | ||
|
|
d9c80a85e5 | ||
|
|
3e62c5513a | ||
|
|
cd18582578 | ||
|
|
80a44b97f5 | ||
|
|
9187a09483 | ||
|
|
3041e5c354 |
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
@@ -127,8 +127,8 @@ class CustomNodeManager:
|
||||
|
||||
if os.path.exists(workflows_dir):
|
||||
if folder_name != "example_workflows":
|
||||
logging.warning(
|
||||
"WARNING: Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||
logging.debug(
|
||||
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||
folder_name, module_name)
|
||||
|
||||
webapp.add_routes(
|
||||
|
||||
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}):
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
elif name == "R":
|
||||
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
else:
|
||||
raise ValueError(f"Normalization {name} not found")
|
||||
|
||||
@@ -120,15 +119,15 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_q = nn.Sequential(
|
||||
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[0], norm_dim),
|
||||
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_k = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[1], norm_dim),
|
||||
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_v = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[2], norm_dim),
|
||||
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
|
||||
@@ -27,8 +27,6 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
|
||||
|
||||
if self.affline_emb_norm:
|
||||
logging.debug("Building affine embedding normalization layer")
|
||||
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
||||
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||
else:
|
||||
self.affline_norm = nn.Identity()
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
|
||||
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
||||
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
|
||||
if norm_type == "layer":
|
||||
norm_layer = operations.LayerNorm
|
||||
elif norm_type == "rms":
|
||||
norm_layer = RMSNorm
|
||||
norm_layer = operations.RMSNorm
|
||||
else:
|
||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
|
||||
)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
cap_feat_dim,
|
||||
dim,
|
||||
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
|
||||
@@ -9,7 +9,6 @@ from einops import repeat
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
@@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
|
||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs):
|
||||
r"""
|
||||
@@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_img_len):
|
||||
r"""
|
||||
|
||||
@@ -12,6 +12,46 @@ import torch
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@@ -89,7 +129,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
@@ -116,7 +156,9 @@ class VideoFromFile(VideoInput):
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container:
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
@@ -211,7 +253,12 @@ class VideoFromComponents(VideoInput):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy()
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
|
||||
91
tests-unit/comfy_api_test/input_impl_test.py
Normal file
91
tests-unit/comfy_api_test/input_impl_test.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import io
|
||||
from comfy_api.input_impl.video_types import (
|
||||
container_to_output_format,
|
||||
get_open_write_kwargs,
|
||||
)
|
||||
from comfy_api.util import VideoContainer
|
||||
|
||||
|
||||
def test_container_to_output_format_empty_string():
|
||||
"""Test that an empty string input returns None. `None` arg allows default auto-detection."""
|
||||
assert container_to_output_format("") is None
|
||||
|
||||
|
||||
def test_container_to_output_format_none():
|
||||
"""Test that None input returns None."""
|
||||
assert container_to_output_format(None) is None
|
||||
|
||||
|
||||
def test_container_to_output_format_comma_separated():
|
||||
"""Test that a comma-separated list returns a valid singular format from the list."""
|
||||
comma_separated_format = "mp4,mov,m4a"
|
||||
output_format = container_to_output_format(comma_separated_format)
|
||||
assert output_format in comma_separated_format
|
||||
|
||||
|
||||
def test_container_to_output_format_single():
|
||||
"""Test that a single format string (not comma-separated list) is returned as is."""
|
||||
assert container_to_output_format("mp4") == "mp4"
|
||||
|
||||
|
||||
def test_get_open_write_kwargs_filepath_no_format():
|
||||
"""Test that 'format' kwarg is NOT set when dest is a file path."""
|
||||
kwargs_auto = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
|
||||
assert "format" not in kwargs_auto, "Format should not be set for file paths (AUTO)"
|
||||
|
||||
kwargs_specific = get_open_write_kwargs("output.avi", "mp4", "avi")
|
||||
fail_msg = "Format should not be set for file paths (Specific)"
|
||||
assert "format" not in kwargs_specific, fail_msg
|
||||
|
||||
|
||||
def test_get_open_write_kwargs_base_options_mode():
|
||||
"""Test basic kwargs for file path: mode and movflags."""
|
||||
kwargs = get_open_write_kwargs("output.mp4", "mp4", VideoContainer.AUTO)
|
||||
assert kwargs["mode"] == "w", "mode should be set to write"
|
||||
|
||||
fail_msg = "movflags should be set to preserve custom metadata tags"
|
||||
assert "movflags" in kwargs["options"], fail_msg
|
||||
assert kwargs["options"]["movflags"] == "use_metadata_tags", fail_msg
|
||||
|
||||
|
||||
def test_get_open_write_kwargs_bytesio_auto_format():
|
||||
"""Test kwargs for BytesIO dest with AUTO format."""
|
||||
dest = io.BytesIO()
|
||||
container_fmt = "mov,mp4,m4a"
|
||||
kwargs = get_open_write_kwargs(dest, container_fmt, VideoContainer.AUTO)
|
||||
|
||||
assert kwargs["mode"] == "w"
|
||||
assert kwargs["options"]["movflags"] == "use_metadata_tags"
|
||||
|
||||
fail_msg = (
|
||||
"Format should be a valid format from the container's format list when AUTO"
|
||||
)
|
||||
assert kwargs["format"] in container_fmt, fail_msg
|
||||
|
||||
|
||||
def test_get_open_write_kwargs_bytesio_specific_format():
|
||||
"""Test kwargs for BytesIO dest with a specific single format."""
|
||||
dest = io.BytesIO()
|
||||
container_fmt = "avi"
|
||||
to_fmt = VideoContainer.MP4
|
||||
kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)
|
||||
|
||||
assert kwargs["mode"] == "w"
|
||||
assert kwargs["options"]["movflags"] == "use_metadata_tags"
|
||||
|
||||
fail_msg = "Format should be the specified format (lowercased) when output format is not AUTO"
|
||||
assert kwargs["format"] == "mp4", fail_msg
|
||||
|
||||
|
||||
def test_get_open_write_kwargs_bytesio_specific_format_list():
|
||||
"""Test kwargs for BytesIO dest with a specific comma-separated format."""
|
||||
dest = io.BytesIO()
|
||||
container_fmt = "avi"
|
||||
to_fmt = "mov,mp4,m4a" # A format string that is a list
|
||||
kwargs = get_open_write_kwargs(dest, container_fmt, to_fmt)
|
||||
|
||||
assert kwargs["mode"] == "w"
|
||||
assert kwargs["options"]["movflags"] == "use_metadata_tags"
|
||||
|
||||
fail_msg = "Format should be a valid format from the specified format list when output format is not AUTO"
|
||||
assert kwargs["format"] in to_fmt, fail_msg
|
||||
Reference in New Issue
Block a user