Compare commits

...

59 Commits

Author SHA1 Message Date
bymyself
d65ad9940b add swagger validator workflow 2025-05-21 21:15:41 -07:00
bymyself
e8a92e4c9b Fix linting issues in API tests 2025-05-20 12:26:56 -07:00
bymyself
fa9688b1fb [docs] Add OpenAPI specification and test framework 2025-05-20 12:15:46 -07:00
comfyanonymous
87f9130778 Revert "This doesn't seem to be needed on chroma. (#8209)" (#8210)
This reverts commit 7e84bf5373.
2025-05-20 05:39:55 -04:00
comfyanonymous
7e84bf5373 This doesn't seem to be needed on chroma. (#8209) 2025-05-20 05:29:23 -04:00
filtered
4f3b50ba51 Update README ROCm text to match link (#8199)
- Follow-up on #8198
2025-05-19 16:40:55 -04:00
comfyanonymous
e930a387d6 Update AMD instructions in README. (#8198) 2025-05-19 04:58:41 -04:00
comfyanonymous
d8e5662822 Remove default delimiter. (#8183) 2025-05-18 04:12:12 -04:00
LaVie024
3d44a09812 Update nodes_string.py (#8173) 2025-05-18 04:11:11 -04:00
comfyanonymous
62690eddec Node to add pixel space noise to an image. (#8182) 2025-05-18 04:09:56 -04:00
Christian Byrne
05eb10b43a Validate video inputs (#8133)
* validate kling lip sync input video

* add tooltips

* update duration estimates

* decrease epsilon

* fix rebase error
2025-05-18 04:08:47 -04:00
Silver
f5e4e976f4 Add missing category for T5TokenizerOption (#8177)
Change it if you need to but it should at least have a category.
2025-05-18 02:59:06 -04:00
comfyanonymous
aee2908d03 Remove useless log. (#8166) 2025-05-17 06:27:34 -04:00
comfyanonymous
dc46db7aa4 Make ImagePadForOutpaint return a 3 channel mask. (#8157) 2025-05-16 15:15:55 -04:00
filtered
7046983d95 Remove Desktop versioning claim from README (#8155) 2025-05-16 10:45:36 -07:00
comfyanonymous
1c2d45d2b5 Fix typo in last PR. (#8144)
More robust model detection for future proofing.
2025-05-15 19:02:19 -04:00
George0726
c820ef950d Add Wan-FUN Camera Control models and Add WanCameraImageToVideo node (#8013)
* support wan camera models

* fix by ruff check

* change camera_condition type; make camera_condition optional

* support camera trajectory nodes

* fix camera direction

---------

Co-authored-by: Qirui Sun <sunqr0667@126.com>
2025-05-15 19:00:43 -04:00
comfyanonymous
6a2e4bb9e0 Remove old hack used to fix windows pytorch 2.4 on the portable. (#8139)
Not necessary anymore.
2025-05-15 08:21:47 -04:00
Christian Byrne
f1f9763b4c Add get_duration method to Comfy VIDEO type (#8122)
* get duration from VIDEO type

* video get_duration unit test

* fix Windows unit test: can't delete opened temp file
2025-05-15 00:11:41 -04:00
comfyanonymous
08368f8e00 Update comment on ROCm pytorch attention in README. (#8123) 2025-05-14 17:54:50 -04:00
Christian Byrne
f3ff5c40db don't retry if API returns task failure (#8111) 2025-05-14 01:28:30 -04:00
Christian Byrne
98ff01e148 Display progress and result URL directly on API nodes (#8102)
* [Luma] Print download URL of successful task result directly on nodes (#177)

[Veo] Print download URL of successful task result directly on nodes (#184)

[Recraft] Print download URL of successful task result directly on nodes (#183)

[Pixverse] Print download URL of successful task result directly on nodes (#182)

[Kling] Print download URL of successful task result directly on nodes (#181)

[MiniMax] Print progress text and download URL of successful task result directly on nodes (#179)

[Docs] Link to docs in `API_NODE` class property type annotation comment (#178)

[Ideogram] Print download URL of successful task result directly on nodes (#176)

[Kling] Print download URL of successful task result directly on nodes (#181)

[Veo] Print download URL of successful task result directly on nodes (#184)

[Recraft] Print download URL of successful task result directly on nodes (#183)

[Pixverse] Print download URL of successful task result directly on nodes (#182)

[MiniMax] Print progress text and download URL of successful task result directly on nodes (#179)

[Docs] Link to docs in `API_NODE` class property type annotation comment (#178)

[Luma] Print download URL of successful task result directly on nodes (#177)

[Ideogram] Print download URL of successful task result directly on nodes (#176)

Show output URL and progress text on Pika nodes (#168)

[BFL] Print download URL of successful task result directly on nodes (#175)

[OpenAI ] Print download URL of successful task result directly on nodes (#174)

* fix ruff errors

* fix 3.10 syntax error
2025-05-14 00:33:18 -04:00
thot experiment
bab836d88d rework client.py to be more robust, add logging of api requests (#7988)
* rework how errors are handled on the client side

* add logging to /temp

* fix ruff

* fix rebase, stupid vscode gui
2025-05-13 20:42:29 -04:00
comfyanonymous
4a9014e201 Hunyuan Custom initial untested implementation. (#8101) 2025-05-13 15:53:47 -04:00
thot experiment
8a7c894d54 fix negative momentum (#8100) 2025-05-13 10:50:32 -07:00
comfyanonymous
a814f2e8cc Fix issue with old pytorch RMSNorm. (#8095) 2025-05-13 07:54:28 -04:00
comfyanonymous
481732a0ed Support official ACE Step loras. (#8094) 2025-05-13 07:32:16 -04:00
Christian Byrne
2156ce9453 add comment about using api key in headless (#8082) 2025-05-12 23:06:44 -04:00
thot experiment
4136502b7a implement APG guidance (#8081)
* first pass at impementing AGP

* rename, cleanup code

* fix ruff

* fix modified cond to match ref impl better, support different cond arity
2025-05-12 21:10:24 -04:00
Terry Jia
9ad287ff20 add support to record video as output for 3d node (#7927)
* add support to record video as output for 3d node

* source format

* add support to record video for load3d animation node
2025-05-12 16:47:14 -04:00
Chenlei Hu
f5cacaeb14 Update frontend to v1.19 (#8076)
* Update frontend to v1.19

* Update requirements.txt
2025-05-12 16:47:02 -04:00
Terry Jia
b7ed5f57bd string node (#7952) 2025-05-12 16:29:32 -04:00
thot experiment
b4abca828e add opus and mp3 to audio output node (#8019)
* first pass at opus and mp3 as well as migrating flac to pyav

* minor mp3 encoding fix

* fix ruff

* delete dead code

* split out save audio to separate nodes per filetype

* fix ruff
2025-05-12 16:00:01 -04:00
comfyanonymous
158419f3a0 ComfyUI version 0.3.34 2025-05-12 15:58:28 -04:00
comfyanonymous
640c47e7de Fix torch warning about deprecated function. (#8075)
Drop support for torch versions below 2.2 on the audio VAEs.
2025-05-12 14:32:01 -04:00
Christian Byrne
31e9e36c94 remove aspect ratio from kling request (#8062) 2025-05-12 13:32:24 -04:00
comfyanonymous
577de83ca9 ACE VAE works in fp16. (#8055) 2025-05-11 04:58:00 -04:00
Christian Byrne
3535909eb8 Add support for Comfy API keys (#8041)
* Handle Comfy API key based authorizaton (#167)

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* Bump frontend version to include API key features (#170)

* bump templates version

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-05-10 22:10:58 -04:00
Christian Byrne
235d3901fc Add method to stream text to node UI (#8018)
* show text progress preview

* include node id in message
2025-05-10 20:40:02 -04:00
comfyanonymous
d42613686f Fix issue with fp8 ops on some models. (#8045)
_scaled_mm errors when an input is non contiguous.
2025-05-10 07:52:56 -04:00
Pam
1b3bf0a5da Fix res_multistep_ancestral sampler (#8030) 2025-05-09 20:14:13 -04:00
Christian Byrne
ae60b150e5 update node tooltips and validation (#8036) 2025-05-09 20:02:45 -04:00
blepping
42da274717 Use normal ComfyUI attention in ACE-Steps model (#8023)
* Use normal ComfyUI attention in ACE-Steps model

* Let optimized_attention handle output reshape for ACE
2025-05-09 13:51:02 -04:00
thot experiment
28f178a840 move SVG to core (#7982)
* move SVG to core

* fix workflow embedding w/ unicode characters
2025-05-09 13:46:34 -04:00
comfyanonymous
8ab15c863c Add --mmap-torch-files to enable use of mmap when loading ckpt/pt (#8021) 2025-05-09 04:52:47 -04:00
comfyanonymous
924d771e18 Add ACE Step to README. (#8005) 2025-05-08 08:40:57 -04:00
comfyanonymous
02a1b01aad ComfyUI version 0.3.33 2025-05-08 07:36:48 -04:00
comfyanonymous
a692c3cca4 Make ACE VAE tiling work. (#8004) 2025-05-08 07:25:45 -04:00
comfyanonymous
5d3cc85e13 Make japanese hiragana and katakana characters work with ACE. (#7997) 2025-05-08 03:32:36 -04:00
comfyanonymous
c7c025b8d1 Adjust memory estimation code for ACE VAE. (#7990) 2025-05-08 01:22:23 -04:00
comfyanonymous
fd08e39588 Make torchaudio not a hard requirement. (#7987)
Some platforms can't install it apparently so if it's not there it should
only break models that actually use it.
2025-05-07 21:37:12 -04:00
comfyanonymous
56b6ee6754 Detection code to make ltxv models without config work. (#7986) 2025-05-07 21:28:24 -04:00
comfyanonymous
cc33cd3422 Experimental lyrics strength for ACE. (#7984) 2025-05-07 19:22:07 -04:00
comfyanonymous
b9980592c4 Refuse to load api nodes on old pyav version. (#7981) 2025-05-07 17:27:16 -04:00
comfyanonymous
16417b40d9 Initial ACE-Step model implementation. (#7972) 2025-05-07 08:33:34 -04:00
comfyanonymous
271c9c5b9e Better mem estimation for the LTXV 13B model. (#7963) 2025-05-06 09:52:37 -04:00
comfyanonymous
a4e679765e Change chroma to use Flux shift. (#7961) 2025-05-06 09:00:01 -04:00
comfyanonymous
0cf2e46b17 ComfyUI version 0.3.32 2025-05-06 07:39:54 -04:00
comfyanonymous
094e9ef126 Add a way to disable api nodes: --disable-api-nodes (#7960) 2025-05-06 04:53:53 -04:00
79 changed files with 25605 additions and 616 deletions

View File

@@ -0,0 +1,49 @@
name: Validate OpenAPI
on:
push:
branches: [ master ]
paths:
- 'openapi.yaml'
pull_request:
branches: [ master ]
paths:
- 'openapi.yaml'
jobs:
openapi-check:
runs-on: ubuntu-latest
# Service containers to run with `runner-job`
services:
# Label used to access the service container
swagger-editor:
# Docker Hub image
image: swaggerapi/swagger-editor
ports:
# Maps port 8080 on service container to the host 80
- 80:8080
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Validate OpenAPI definition
uses: swaggerexpert/swagger-editor-validate@v1
with:
definition-file: openapi.yaml
swagger-editor-url: http://localhost/
default-timeout: 20000
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install test dependencies
run: |
pip install -r tests-api/requirements.txt
- name: Run OpenAPI spec validation tests
run: |
pytest tests-api/test_spec_validation.py -v

1
.gitignore vendored
View File

@@ -21,6 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store
openapi.yaml
filtered-openapi.yaml
uv.lock

View File

@@ -69,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -108,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
- Builds a new release using the latest stable core version
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository
@@ -196,11 +197,11 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
### Intel GPUs (Windows and Linux)
@@ -300,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```

View File

@@ -142,12 +142,15 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod

View File

@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_sigma_down = None
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t) / h
c2 = (t_prev - t_old) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = uncond_denoised
else:
old_denoised = denoised
old_sigma_down = sigma_down
return x
@torch.no_grad()

View File

@@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2

761
comfy/ldm/ace/attention.py Normal file
View File

@@ -0,0 +1,761 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
processor=None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
dtype=None, device=None, operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
self.group_norm = None
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
if self.context_pre_only is not None:
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
else:
self.to_add_out = None
self.norm_added_q = None
self.norm_added_k = None
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CustomLiteLAProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
def __init__(self):
self.kernel_func = nn.ReLU(inplace=False)
self.eps = 1e-15
self.pad_val = 1.0
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
hidden_states_len = hidden_states.shape[1]
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
if encoder_hidden_states is not None:
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0]
# `sample` projections.
dtype = hidden_states.dtype
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
if not attn.is_cross_attention:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
else:
query = hidden_states
key = encoder_hidden_states
value = encoder_hidden_states
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
# RoPE需要 [B, H, S, D] 输入
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
# Apply query and key normalization if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
if attention_mask is not None:
# attention_mask: [B, S] -> [B, 1, S, 1]
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
if not attn.is_cross_attention:
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S]那么需调整mask以匹配S维度
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
query = self.kernel_func(query)
key = self.kernel_func(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
vk = torch.matmul(value, key)
hidden_states = torch.matmul(vk, query)
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.float()
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
hidden_states = hidden_states.to(dtype)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(dtype)
# Split the attention outputs.
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
hidden_states, encoder_hidden_states = (
hidden_states[:, : hidden_states_len],
hidden_states[:, hidden_states_len:],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if encoder_hidden_states is not None and context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if torch.get_autocast_gpu_dtype() == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CustomerAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
# attention_mask: N x S1
# encoder_attention_mask: N x S2
# cross attention 整合attention_mask和encoder_attention_mask
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
elif not attn.is_cross_attention and attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
).to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
"""Return tuple with min_len by repeating element at idx_repeat."""
# convert to list first
x = val2list(x)
# repeat elements if necessary
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2
class ConvLayer(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
padding: Union[int, None] = None,
use_bias=False,
norm=None,
act=None,
dtype=None, device=None, operations=None
):
super().__init__()
if padding is None:
padding = get_same_padding(kernel_size)
padding *= dilation
self.in_dim = in_dim
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding = padding
self.use_bias = use_bias
self.conv = operations.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=use_bias,
device=device,
dtype=dtype
)
if norm is not None:
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
else:
self.norm = None
if act is not None:
self.act = nn.SiLU(inplace=True)
else:
self.act = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
out_feature=None,
kernel_size=3,
stride=1,
padding: Union[int, None] = None,
use_bias=False,
norm=(None, None, None),
act=("silu", "silu", None),
dilation=1,
dtype=None, device=None, operations=None
):
out_feature = out_feature or in_features
super().__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act = val2tuple(act, 3)
self.glu_act = nn.SiLU(inplace=False)
self.inverted_conv = ConvLayer(
in_features,
hidden_features * 2,
1,
use_bias=use_bias[0],
norm=norm[0],
act=act[0],
dtype=dtype,
device=device,
operations=operations,
)
self.depth_conv = ConvLayer(
hidden_features * 2,
hidden_features * 2,
kernel_size,
stride=stride,
groups=hidden_features * 2,
padding=padding,
use_bias=use_bias[1],
norm=norm[1],
act=None,
dilation=dilation,
dtype=dtype,
device=device,
operations=operations,
)
self.point_conv = ConvLayer(
hidden_features,
out_feature,
1,
use_bias=use_bias[2],
norm=norm[2],
act=act[2],
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
x = self.inverted_conv(x)
x = self.depth_conv(x)
x, gate = torch.chunk(x, 2, dim=1)
gate = self.glu_act(gate)
x = x * gate
x = self.point_conv(x)
x = x.transpose(1, 2)
return x
class LinearTransformerBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
use_adaln_single=True,
cross_attention_dim=None,
added_kv_proj_dim=None,
context_pre_only=False,
mlp_ratio=4.0,
add_cross_attention=False,
add_cross_attention_dim=None,
qk_norm=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
added_kv_proj_dim=added_kv_proj_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
processor=CustomLiteLAProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.add_cross_attention = add_cross_attention
self.context_pre_only = context_pre_only
if add_cross_attention and add_cross_attention_dim is not None:
self.cross_attn = Attention(
query_dim=dim,
cross_attention_dim=add_cross_attention_dim,
added_kv_proj_dim=add_cross_attention_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
processor=CustomerAttnProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
self.ff = GLUMBConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
use_bias=(True, True, False),
norm=(None, None, None),
act=("silu", "silu", None),
dtype=dtype,
device=device,
operations=operations,
)
self.use_adaln_single = use_adaln_single
if use_adaln_single:
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor = None,
encoder_attention_mask: torch.FloatTensor = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
):
N = hidden_states.shape[0]
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
# step 2: attention
if not self.add_cross_attention:
attn_output, encoder_hidden_states = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
else:
attn_output, _ = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
)
if self.use_adaln_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if self.add_cross_attention:
attn_output = self.cross_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
hidden_states = attn_output + hidden_states
# step 3: add norm
norm_hidden_states = self.norm2(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
# step 4: feed forward
ff_output = self.ff(norm_hidden_states)
if self.use_adaln_single:
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
return hidden_states

File diff suppressed because it is too large Load Diff

385
comfy/ldm/ace/model.py Normal file
View File

@@ -0,0 +1,385 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Union
import torch
from torch import nn
import comfy.model_management
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder
def cross_norm(hidden_states, controlnet_input):
# input N x T x c
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
return controlnet_input
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class T2IFinalLayer(nn.Module):
"""
The final layer of Sana.
"""
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
self.out_channels = out_channels
self.patch_size = patch_size
def unpatchfy(
self,
hidden_states: torch.Tensor,
width: int,
):
# 4 unpatchify
new_height, new_width = 1, hidden_states.size(1)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
).contiguous()
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
).contiguous()
if width > new_width:
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
elif width < new_width:
output = output[:, :, :, :width]
return output
def forward(self, x, t, output_length):
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
# unpatchify
output = self.unpatchfy(x, output_length)
return output
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=16,
width=4096,
patch_size=(16, 1),
in_channels=8,
embed_dim=1152,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
patch_size_h, patch_size_w = patch_size
self.early_conv_layers = nn.Sequential(
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
)
self.patch_size = patch_size
self.height, self.width = height // patch_size_h, width // patch_size_w
self.base_size = self.width
def forward(self, latent):
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
latent = self.early_conv_layers(latent)
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
return latent
class ACEStepTransformer2DModel(nn.Module):
# _supports_gradient_checkpointing = True
def __init__(
self,
in_channels: Optional[int] = 8,
num_layers: int = 28,
inner_dim: int = 1536,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
mlp_ratio: float = 4.0,
out_channels: int = 8,
max_position: int = 32768,
rope_theta: float = 1000000.0,
speaker_embedding_dim: int = 512,
text_embedding_dim: int = 768,
ssl_encoder_depths: List[int] = [9, 9],
ssl_names: List[str] = ["mert", "m-hubert"],
ssl_latent_dims: List[int] = [1024, 768],
lyric_encoder_vocab_size: int = 6681,
lyric_hidden_size: int = 1024,
patch_size: List[int] = [16, 1],
max_height: int = 16,
max_width: int = 4096,
audio_model=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.dtype = dtype
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.out_channels = out_channels
self.max_position = max_position
self.patch_size = patch_size
self.rope_theta = rope_theta
self.rotary_emb = Qwen2RotaryEmbedding(
dim=self.attention_head_dim,
max_position_embeddings=self.max_position,
base=self.rope_theta,
dtype=dtype,
device=device,
)
# 2. Define input layers
self.in_channels = in_channels
self.num_layers = num_layers
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
LinearTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
add_cross_attention=True,
add_cross_attention_dim=self.inner_dim,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(self.num_layers)
]
)
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
# speaker
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# genre
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# lyric
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
projector_dim = 2 * self.inner_dim
self.projectors = nn.ModuleList([
nn.Sequential(
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
) for ssl_dim in ssl_latent_dims
])
self.proj_in = PatchEmbed(
height=max_height,
width=max_width,
patch_size=patch_size,
embed_dim=self.inner_dim,
bias=True,
dtype=dtype,
device=device,
operations=operations,
)
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
def forward_lyric_encoder(
self,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
out_dtype=None,
):
# N x T x D
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
return prompt_prenet_out
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
lyrics_strength=1.0,
):
bs = encoder_text_hidden_states.shape[0]
device = encoder_text_hidden_states.device
# speaker embedding
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
# genre embedding
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
# lyric
encoder_lyric_hidden_states = self.forward_lyric_encoder(
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_lyric_hidden_states *= lyrics_strength
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
if text_attention_mask is not None:
speaker_mask = torch.ones(bs, 1, device=device)
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
return encoder_hidden_states, encoder_hidden_mask
def decode(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
hidden_states = self.proj_in(hidden_states)
# controlnet logic
if block_controlnet_hidden_states is not None:
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
hidden_states = hidden_states + control_condi * controlnet_scale
# inner_hidden_states = []
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_hidden_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
hidden_states = x
encoder_text_hidden_states = context
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
lyrics_strength=lyrics_strength,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
)
return output

View File

@@ -0,0 +1,644 @@
# Rewritten from diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
class RMSNorm(ops.RMSNorm):
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
def forward(self, x):
x = super().forward(x)
if self.elementwise_affine:
if self.bias is not None:
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
return x
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
if norm_type == "batch_norm":
return nn.BatchNorm2d(num_features)
elif norm_type == "group_norm":
return ops.GroupNorm(num_groups, num_features)
elif norm_type == "layer_norm":
return ops.LayerNorm(num_features)
elif norm_type == "rms_norm":
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
else:
raise ValueError(f"Unknown normalization type: {norm_type}")
def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "silu":
return nn.SiLU()
elif activation_type == "leaky_relu":
return nn.LeakyReLU(0.2)
else:
raise ValueError(f"Unknown activation type: {activation_type}")
class ResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm_type: str = "batch_norm",
act_fn: str = "relu6",
) -> None:
super().__init__()
self.norm_type = norm_type
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
self.norm = get_normalization(norm_type, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm(hidden_states)
return hidden_states + residual
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = ops.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: int = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: tuple = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult)
if num_attention_heads is None
else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, out_channels)
def apply_linear_attention(self, query, key, value):
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
def apply_quadratic_attention(self, query, key, value):
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores.to(value.dtype))
return hidden_states
def forward(self, hidden_states):
height, width = hidden_states.shape[-2:]
if height * width > self.attention_head_dim:
use_linear_attention = True
else:
use_linear_attention = False
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in self.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = self.nonlinearity(query)
key = self.nonlinearity(key)
if use_linear_attention:
hidden_states = self.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = self.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.norm_type == "rms_norm":
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm_out(hidden_states)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: tuple = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
self.attn = SanaMultiscaleLinearAttention(
in_channels=in_channels,
out_channels=in_channels,
mult=mult,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
kernel_sizes=qkv_multiscales,
residual_connection=True,
)
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
norm_type="rms_norm",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.conv_out(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: str = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
def get_block(
block_type: str,
in_channels: int,
out_channels: int,
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: tuple = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
qkv_multiscales=qkv_mutliscales
)
else:
raise ValueError(f"Block with {block_type=} is not supported.")
return block
class DCDownBlock2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
super().__init__()
self.downsample = downsample
self.factor = 2
self.stride = 1 if downsample else 2
self.group_size = in_channels * self.factor**2 // out_channels
self.shortcut = shortcut
out_ratio = self.factor**2
if downsample:
assert out_channels % out_ratio == 0
out_channels = out_channels // out_ratio
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.conv(hidden_states)
if self.downsample:
x = F.pixel_unshuffle(x, self.factor)
if self.shortcut:
y = F.pixel_unshuffle(hidden_states, self.factor)
y = y.unflatten(1, (-1, self.group_size))
y = y.mean(dim=2)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class DCUpBlock2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
interpolate: bool = False,
shortcut: bool = True,
interpolation_mode: str = "nearest",
) -> None:
super().__init__()
self.interpolate = interpolate
self.interpolation_mode = interpolation_mode
self.shortcut = shortcut
self.factor = 2
self.repeats = out_channels * self.factor**2 // in_channels
out_ratio = self.factor**2
if not interpolate:
out_channels = out_channels * out_ratio
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.interpolate:
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
x = self.conv(x)
else:
x = self.conv(hidden_states)
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if layers_per_block[0] > 0:
self.conv_in = ops.Conv2d(
in_channels,
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
kernel_size=3,
stride=1,
padding=1,
)
else:
self.conv_in = DCDownBlock2d(
in_channels=in_channels,
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=False,
)
down_blocks = []
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
down_block_list = []
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_mutliscales=qkv_multiscales[i],
)
down_block_list.append(block)
if i < num_blocks - 1 and num_layers > 0:
downsample_block = DCDownBlock2d(
in_channels=out_channel,
out_channels=block_out_channels[i + 1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=True,
)
down_block_list.append(downsample_block)
down_blocks.append(nn.Sequential(*down_block_list))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
self.out_shortcut = out_shortcut
if out_shortcut:
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
if self.out_shortcut:
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
x = x.mean(dim=2)
hidden_states = self.conv_out(hidden_states) + x
else:
hidden_states = self.conv_out(hidden_states)
return hidden_states
class Decoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
norm_type: str or tuple = "rms_norm",
act_fn: str or tuple = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if isinstance(norm_type, str):
norm_type = (norm_type,) * num_blocks
if isinstance(act_fn, str):
act_fn = (act_fn,) * num_blocks
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
self.in_shortcut = in_shortcut
if in_shortcut:
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
up_blocks = []
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
up_block_list = []
if i < num_blocks - 1 and num_layers > 0:
upsample_block = DCUpBlock2d(
block_out_channels[i + 1],
out_channel,
interpolate=upsample_block_type == "interpolate",
shortcut=True,
)
up_block_list.append(upsample_block)
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_mutliscales=qkv_multiscales[i],
)
up_block_list.append(block)
up_blocks.insert(0, nn.Sequential(*up_block_list))
self.up_blocks = nn.ModuleList(up_blocks)
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_out = None
if layers_per_block[0] > 0:
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
else:
self.conv_out = DCUpBlock2d(
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
for up_block in reversed(self.up_blocks):
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderDC(nn.Module):
def __init__(
self,
in_channels: int = 2,
latent_channels: int = 8,
attention_head_dim: int = 32,
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
upsample_block_type: str = "interpolate",
downsample_block_type: str = "Conv",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
scaling_factor: float = 0.41407,
) -> None:
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=encoder_block_types,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
)
self.decoder = Decoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=decoder_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
qkv_multiscales=decoder_qkv_multiscales,
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
)
self.scaling_factor = scaling_factor
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Internal encoding function."""
encoded = self.encoder(x)
return encoded * self.scaling_factor
def decode(self, z: torch.Tensor) -> torch.Tensor:
# Scale the latents back
z = z / self.scaling_factor
decoded = self.decoder(z)
return decoded
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = self.encode(x)
return self.decode(z)

View File

@@ -0,0 +1,109 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1
class MusicDCAE(torch.nn.Module):
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
super(MusicDCAE, self).__init__()
self.dcae = AutoencoderDC(**dcae_config)
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
if source_sample_rate is None:
self.source_sample_rate = 48000
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
self.min_mel_value = -11.0
self.max_mel_value = 3.0
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
self.mel_chunk_size = 1024
self.time_dimention_multiple = 8
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
image = self.vocoder.mel_transform(audios[i])
mels.append(image)
mels = torch.stack(mels)
return mels
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
if audio_lengths is None:
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
audio_lengths = audio_lengths.to(audios.device)
if sr is None:
sr = self.source_sample_rate
if sr != 44100:
audios = torchaudio.functional.resample(audios, sr, 44100)
max_audio_len = audios.shape[-1]
if max_audio_len % (8 * 512) != 0:
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
mels = self.forward_mel(audios)
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
mels = self.transform(mels)
latents = []
for mel in mels:
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
latents = latents / self.scale_factor + self.shift_factor
pred_wavs = []
for latent in latents:
mels = self.dcae.decoder(latent.unsqueeze(0))
mels = mels * 0.5 + 0.5
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
return sr, pred_wavs, latents, latent_lengths

View File

@@ -0,0 +1,113 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import torch
import torch.nn as nn
from torch import Tensor
import logging
try:
from torchaudio.transforms import MelScale
except:
logging.warning("torchaudio missing, ACE model will be broken")
import comfy.model_management
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.mode = mode
self.register_buffer("window", torch.hann_window(win_length))
def forward(self, y: Tensor) -> Tensor:
if y.ndim == 3:
y = y.squeeze(1)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
(self.win_length - self.hop_length) // 2,
(self.win_length - self.hop_length + 1) // 2,
),
mode="reflect",
).squeeze(1)
dtype = y.dtype
spec = torch.stft(
y.float(),
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = spec.to(dtype)
return spec
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max or sample_rate // 2
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
"slaney",
"slaney",
)
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
linear = self.spectrogram(x)
x = self.mel_scale(linear)
x = self.compress(x)
# print(x.shape)
if return_linear:
return x, self.compress(linear)
return x

View File

@@ -0,0 +1,538 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
import torch
from torch import nn
from functools import partial
from math import prod
from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from .music_log_mel import LogMelSpectrogram
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
""" # noqa: E501
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class LayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
""" # noqa: E501
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
return x
class ConvNeXtBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
kernel_size (int): Kernel size for depthwise conv. Default: 7.
dilation (int): Dilation for depthwise conv. Default: 1.
""" # noqa: E501
def __init__(
self,
dim: int,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-6,
mlp_ratio: float = 4.0,
kernel_size: int = 7,
dilation: int = 1,
):
super().__init__()
self.dwconv = ops.Conv1d(
dim,
dim,
kernel_size=kernel_size,
padding=int(dilation * (kernel_size - 1) / 2),
groups=dim,
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = ops.Linear(
dim, int(mlp_ratio * dim)
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
self.gamma = (
nn.Parameter(torch.empty((dim)), requires_grad=False)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x, apply_residual: bool = True):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
x = self.drop_path(x)
if apply_residual:
x = input + x
return x
class ParallelConvNeXtBlock(nn.Module):
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
super().__init__()
self.blocks = nn.ModuleList(
[
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
for kernel_size in kernel_sizes
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.stack(
[block(x, apply_residual=False) for block in self.blocks] + [x],
dim=1,
).sum(dim=1)
class ConvNeXtEncoder(nn.Module):
def __init__(
self,
input_channels=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.0,
layer_scale_init_value=1e-6,
kernel_sizes: Tuple[int] = (7,),
):
super().__init__()
assert len(depths) == len(dims)
self.channel_layers = nn.ModuleList()
stem = nn.Sequential(
ops.Conv1d(
input_channels,
dims[0],
kernel_size=7,
padding=3,
padding_mode="replicate",
),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.channel_layers.append(stem)
for i in range(len(depths) - 1):
mid_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
)
self.channel_layers.append(mid_layer)
block_fn = (
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
if len(kernel_sizes) == 1
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
)
self.stages = nn.ModuleList()
drop_path_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(len(depths)):
stage = nn.Sequential(
*[
block_fn(
dim=dims[i],
drop_path=drop_path_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for channel_layer, stage in zip(self.channel_layers, self.stages):
x = channel_layer(x)
x = stage(x)
return self.norm(x)
def get_padding(kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.silu(x)
xt = c1(xt)
xt = F.silu(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for conv in self.convs1:
remove_weight_norm(conv)
for conv in self.convs2:
remove_weight_norm(conv)
class HiFiGANGenerator(nn.Module):
def __init__(
self,
*,
hop_length: int = 512,
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 128,
upsample_initial_channel: int = 512,
use_template: bool = True,
pre_conv_kernel_size: int = 7,
post_conv_kernel_size: int = 7,
post_activation: Callable = partial(nn.SiLU, inplace=True),
):
super().__init__()
assert (
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
pre_conv_kernel_size,
1,
padding=get_padding(pre_conv_kernel_size),
)
)
self.num_upsamples = len(upsample_rates)
self.num_kernels = len(resblock_kernel_sizes)
self.noise_convs = nn.ModuleList()
self.use_template = use_template
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if not use_template:
continue
if i + 1 < len(upsample_rates):
stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(
ops.Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
ch,
1,
post_conv_kernel_size,
1,
padding=get_padding(post_conv_kernel_size),
)
)
def forward(self, x, template=None):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.silu(x, inplace=True)
x = self.ups[i](x)
if self.use_template:
x = x + self.noise_convs[i](template)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for up in self.ups:
remove_weight_norm(up)
for block in self.resblocks:
block.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class ADaMoSHiFiGANV1(nn.Module):
def __init__(
self,
input_channels: int = 128,
depths: List[int] = [3, 3, 9, 3],
dims: List[int] = [128, 256, 384, 512],
drop_path_rate: float = 0.0,
kernel_sizes: Tuple[int] = (7,),
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 512,
upsample_initial_channel: int = 1024,
use_template: bool = False,
pre_conv_kernel_size: int = 13,
post_conv_kernel_size: int = 13,
sampling_rate: int = 44100,
n_fft: int = 2048,
win_length: int = 2048,
hop_length: int = 512,
f_min: int = 40,
f_max: int = 16000,
n_mels: int = 128,
):
super().__init__()
self.backbone = ConvNeXtEncoder(
input_channels=input_channels,
depths=depths,
dims=dims,
drop_path_rate=drop_path_rate,
kernel_sizes=kernel_sizes,
)
self.head = HiFiGANGenerator(
hop_length=hop_length,
upsample_rates=upsample_rates,
upsample_kernel_sizes=upsample_kernel_sizes,
resblock_kernel_sizes=resblock_kernel_sizes,
resblock_dilation_sizes=resblock_dilation_sizes,
num_mels=num_mels,
upsample_initial_channel=upsample_initial_channel,
use_template=use_template,
pre_conv_kernel_size=pre_conv_kernel_size,
post_conv_kernel_size=post_conv_kernel_size,
)
self.sampling_rate = sampling_rate
self.mel_transform = LogMelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
)
self.eval()
@torch.no_grad()
def decode(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y
@torch.no_grad()
def encode(self, x):
return self.mel_transform(x)
def forward(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y

View File

@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":

View File

@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
return c_skip, c
class WanCamAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
super(WanCamAdapter, self).__init__()
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
)
def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)
# Convolution operation
x_conv = self.conv(x_unshuffled)
# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)
# Reshape to restore original bf dimension
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
out = out.permute(0, 2, 1, 3, 4)
return out
class WanCamResidualBlock(nn.Module):
def __init__(self, dim, operation_settings={}):
super(WanCamResidualBlock, self).__init__()
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.relu = nn.ReLU(inplace=True)
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class CameraWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='camera',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
in_dim_control_adapter=24,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
camera_conditions = None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map

View File

@@ -39,6 +39,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.model_management
import comfy.patcher_extension
@@ -923,6 +924,10 @@ class HunyuanVideo(BaseModel):
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):
@@ -1074,6 +1079,17 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
camera_conditions = kwargs.get("camera_conditions", None)
if camera_conditions is not None:
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1111,7 +1127,7 @@ class HiDream(BaseModel):
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
def extra_conds(self, **kwargs):
@@ -1121,3 +1137,22 @@ class Chroma(Flux):
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
noise = kwargs.get("noise", None)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
if cross_attn is not None:
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out

View File

@@ -222,10 +222,39 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
dit_config = {}
dit_config["audio_model"] = "ace"
dit_config["attention_head_dim"] = 128
dit_config["in_channels"] = 8
dit_config["inner_dim"] = 2560
dit_config["max_height"] = 16
dit_config["max_position"] = 32768
dit_config["max_width"] = 32768
dit_config["mlp_ratio"] = 2.5
dit_config["num_attention_heads"] = 20
dit_config["num_layers"] = 24
dit_config["out_channels"] = 8
dit_config["patch_size"] = [16, 1]
dit_config["rope_theta"] = 1000000.0
dit_config["speaker_embedding_dim"] = 512
dit_config["text_embedding_dim"] = 768
dit_config["ssl_encoder_depths"] = [8, 8]
dit_config["ssl_latent_dims"] = [1024, 768]
dit_config["ssl_names"] = ["mert", "m-hubert"]
dit_config["lyric_encoder_vocab_size"] = 6693
dit_config["lyric_hidden_size"] = 1024
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
@@ -332,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@@ -308,10 +308,10 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)

View File

@@ -30,7 +30,7 @@ if RMSNorm is None:
def __init__(
self,
normalized_shape,
eps=None,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,

View File

@@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
@@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.model_patcher
import comfy.lora
@@ -280,6 +282,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -437,6 +440,20 @@ class VAE:
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -495,7 +512,13 @@ class VAE:
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@@ -515,9 +538,24 @@ class VAE:
samples /= 3.0
return samples
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
extra_channel_size = self.extra_1d_channel
out_channels = self.latent_channels * extra_channel_size
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
@@ -542,7 +580,7 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
@@ -609,7 +647,7 @@ class VAE:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1:
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
@@ -715,6 +753,7 @@ class CLIPType(Enum):
WAN = 13
HIDREAM = 14
CHROMA = 15
ACE = 16
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -840,8 +879,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
clip_target.clip = comfy.text_encoders.ace.AceT5Model
clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

View File

@@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
from . import supported_models_base
from . import latent_formats
@@ -785,6 +786,10 @@ class LTXV(supported_models_base.BASE):
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
@@ -987,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class WAN21_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera",
"in_dim": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1096,6 +1111,34 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
}
unet_extra_config = {
}
sampling_settings = {
"shift": 3.0,
}
latent_format = comfy.latent_formats.ACEAudio
memory_usage_factor = 0.5
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.ACEStep(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]

153
comfy/text_encoders/ace.py Normal file
View File

@@ -0,0 +1,153 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging
from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
"ko": 6152, "hi": 6680
}
structure_pattern = re.compile(r"\[.*?\]")
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt, lang):
txt = multilingual_cleaners(txt, lang)
return txt
def encode(self, txt, lang='en'):
# lang = lang.split("-")[0] # remove the region
# self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
lang = "zh-cn" if lang == "zh" else lang
txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids
def get_lang(self, line):
if line.startswith("[") and line[3:4] == ']':
lang = line[1:3].lower()
if lang in SUPPORT_LANGUAGES:
return lang, line[4:]
return "en", line
def __call__(self, string):
lines = string.split("\n")
lyric_token_idx = [261]
for line in lines:
line = line.strip()
if not line:
lyric_token_idx += [2]
continue
lang, line = self.get_lang(line)
if lang not in SUPPORT_LANGUAGES:
lang = "en"
if "zh" in lang:
lang = "zh"
if "spa" in lang:
lang = "es"
try:
line_out = japanese_to_romaji(line)
if line_out != line:
lang = "ja"
line = line_out
except:
pass
try:
if structure_pattern.match(line):
token_idx = self.encode(line, "en")
else:
token_idx = self.encode(line, lang)
lyric_token_idx = lyric_token_idx + token_idx + [2]
except Exception as e:
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
return {"input_ids": lyric_token_idx}
@staticmethod
def from_pretrained(path, **kwargs):
return VoiceBpeTokenizer(path, **kwargs)
def get_vocab(self):
return {}
class UMT5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LyricsTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
class AceT5Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.umt5base.untokenize(token_weight_pair)
def state_dict(self):
return self.umt5base.state_dict()
class AceT5Model(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__()
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def set_clip_options(self, options):
self.umt5base.set_clip_options(options)
def reset_clip_options(self):
self.umt5base.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
def load_sd(self, sd):
return self.umt5base.load_sd(sd)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,395 @@
# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?
import re
def japanese_to_romaji(japanese_text):
"""
Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
Args:
japanese_text (str): Text containing hiragana and/or katakana characters
Returns:
str: The romaji (Latin alphabet) equivalent
"""
# Dictionary mapping kana characters to their romaji equivalents
kana_map = {
# Katakana characters
'': 'a', '': 'i', '': 'u', '': 'e', '': 'o',
'': 'ka', '': 'ki', '': 'ku', '': 'ke', '': 'ko',
'': 'sa', '': 'shi', '': 'su', '': 'se', '': 'so',
'': 'ta', '': 'chi', '': 'tsu', '': 'te', '': 'to',
'': 'na', '': 'ni', '': 'nu', '': 'ne', '': 'no',
'': 'ha', '': 'hi', '': 'fu', '': 'he', '': 'ho',
'': 'ma', '': 'mi', '': 'mu', '': 'me', '': 'mo',
'': 'ya', '': 'yu', '': 'yo',
'': 'ra', '': 'ri', '': 'ru', '': 're', '': 'ro',
'': 'wa', '': 'wo', '': 'n',
# Katakana voiced consonants
'': 'ga', '': 'gi', '': 'gu', '': 'ge', '': 'go',
'': 'za', '': 'ji', '': 'zu', '': 'ze', '': 'zo',
'': 'da', '': 'ji', '': 'zu', '': 'de', '': 'do',
'': 'ba', '': 'bi', '': 'bu', '': 'be', '': 'bo',
'': 'pa', '': 'pi', '': 'pu', '': 'pe', '': 'po',
# Katakana combinations
'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
# Katakana small characters and special cases
'': '', # Small tsu (doubles the following consonant)
'': 'ya', '': 'yu', '': 'yo',
# Katakana extras
'': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
# Hiragana characters
'': 'a', '': 'i', '': 'u', '': 'e', '': 'o',
'': 'ka', '': 'ki', '': 'ku', '': 'ke', '': 'ko',
'': 'sa', '': 'shi', '': 'su', '': 'se', '': 'so',
'': 'ta', '': 'chi', '': 'tsu', '': 'te', '': 'to',
'': 'na', '': 'ni', '': 'nu', '': 'ne', '': 'no',
'': 'ha', '': 'hi', '': 'fu', '': 'he', '': 'ho',
'': 'ma', '': 'mi', '': 'mu', '': 'me', '': 'mo',
'': 'ya', '': 'yu', '': 'yo',
'': 'ra', '': 'ri', '': 'ru', '': 're', '': 'ro',
'': 'wa', '': 'wo', '': 'n',
# Hiragana voiced consonants
'': 'ga', '': 'gi', '': 'gu', '': 'ge', '': 'go',
'': 'za', '': 'ji', '': 'zu', '': 'ze', '': 'zo',
'': 'da', '': 'ji', '': 'zu', '': 'de', '': 'do',
'': 'ba', '': 'bi', '': 'bu', '': 'be', '': 'bo',
'': 'pa', '': 'pi', '': 'pu', '': 'pe', '': 'po',
# Hiragana combinations
'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
# Hiragana small characters and special cases
'': '', # Small tsu (doubles the following consonant)
'': 'ya', '': 'yu', '': 'yo',
# Common punctuation and spaces
' ': ' ', # Japanese space
'': ', ', '': '. ',
}
result = []
i = 0
while i < len(japanese_text):
# Check for small tsu (doubling the following consonant)
if i < len(japanese_text) - 1 and (japanese_text[i] == '' or japanese_text[i] == ''):
if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
next_romaji = kana_map[japanese_text[i+1]]
if next_romaji and next_romaji[0] not in 'aiueon':
result.append(next_romaji[0]) # Double the consonant
i += 1
continue
# Check for combinations with small ya, yu, yo
if i < len(japanese_text) - 1 and japanese_text[i+1] in ('', '', '', '', '', ''):
combo = japanese_text[i:i+2]
if combo in kana_map:
result.append(kana_map[combo])
i += 2
continue
# Regular character
if japanese_text[i] in kana_map:
result.append(kana_map[japanese_text[i]])
else:
# If it's not in our map, keep it as is (might be kanji, romaji, etc.)
result.append(japanese_text[i])
i += 1
return ''.join(result)
def number_to_text(num, ordinal=False):
"""
Convert a number (int or float) to its text representation.
Args:
num: The number to convert
Returns:
str: Text representation of the number
"""
if not isinstance(num, (int, float)):
return "Input must be a number"
# Handle special case of zero
if num == 0:
return "zero"
# Handle negative numbers
negative = num < 0
num = abs(num)
# Handle floats
if isinstance(num, float):
# Split into integer and decimal parts
int_part = int(num)
# Convert both parts
int_text = _int_to_text(int_part)
# Handle decimal part (convert to string and remove '0.')
decimal_str = str(num).split('.')[1]
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
result = int_text + decimal_text
else:
# Handle integers
result = _int_to_text(num)
# Add 'negative' prefix for negative numbers
if negative:
result = "negative " + result
return result
def _int_to_text(num):
"""Helper function to convert an integer to text"""
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
"seventeen", "eighteen", "nineteen"]
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
if num < 20:
return ones[num]
if num < 100:
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
if num < 1000:
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
if num < 1000000:
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
if num < 1000000000:
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
def _digit_to_text(digit):
"""Convert a single digit to text"""
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
return digits[digit]
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
"en": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
],
}
def expand_abbreviations_multilingual(text, lang="en"):
for regex, replacement in _abbreviations[lang]:
text = re.sub(regex, replacement, text)
return text
_symbols_multilingual = {
"en": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree "),
]
],
}
def expand_symbols_multilingual(text, lang="en"):
for regex, replacement in _symbols_multilingual[lang]:
text = re.sub(regex, replacement, text)
text = text.replace(" ", " ") # Ensure there are no double spaces
return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang="en"):
amount = m.group(1).replace(",", ".")
return number_to_text(float(amount))
def _expand_currency(m, lang="en", currency="USD"):
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
full_amount = number_to_text(amount)
and_equivalents = {
"en": ", ",
"es": " con ",
"fr": " et ",
"de": " und ",
"pt": " e ",
"it": " e ",
"pl": ", ",
"cs": ", ",
"ru": ", ",
"nl": ", ",
"ar": ", ",
"tr": ", ",
"hu": ", ",
"ko": ", ",
}
if amount.is_integer():
last_and = full_amount.rfind(and_equivalents[lang])
if last_and != -1:
full_amount = full_amount[:last_and]
return full_amount
def _expand_ordinal(m, lang="en"):
return number_to_text(int(m.group(1)), ordinal=True)
def _expand_number(m, lang="en"):
return number_to_text(int(m.group(0)))
def expand_numbers_multilingual(text, lang="en"):
if lang in ["en", "ru"]:
text = re.sub(_comma_number_re, _remove_commas, text)
else:
text = re.sub(_dot_number_re, _remove_dots, text)
try:
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
except:
pass
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def multilingual_cleaners(text, lang):
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
text = lowercase(text)
try:
text = expand_numbers_multilingual(text, lang)
except:
pass
try:
text = expand_abbreviations_multilingual(text, lang)
except:
pass
try:
text = expand_symbols_multilingual(text, lang=lang)
except:
pass
text = collapse_whitespace(text)
return text
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text

View File

@@ -0,0 +1,22 @@
{
"d_ff": 2048,
"d_kv": 64,
"d_model": 768,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 12,
"num_heads": 12,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}

View File

@@ -28,6 +28,9 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
MMAP_TORCH_FILES = args.mmap_torch_files
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -67,12 +70,14 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
else:
torch_args = {}
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:

View File

@@ -43,3 +43,13 @@ class VideoInput(ABC):
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

View File

@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import io
import logging
from typing import Optional
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
@@ -14,6 +15,7 @@ from comfy_api_nodes.apis.client import (
UploadRequest,
UploadResponse,
)
from server import PromptServer
import numpy as np
@@ -59,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
@@ -93,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
img = Image.open(io.BytesIO(img_data))
elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
@@ -314,7 +322,7 @@ def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
auth_token: Optional[str] = None,
auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
@@ -323,7 +331,7 @@ def upload_file_to_comfyapi(
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_token: Optional authentication token.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
@@ -337,7 +345,7 @@ def upload_file_to_comfyapi(
response_model=UploadResponse,
),
request=request_object,
auth_token=auth_token,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = operation.execute()
@@ -351,7 +359,7 @@ def upload_file_to_comfyapi(
def upload_video_to_comfyapi(
video: VideoInput,
auth_token: Optional[str] = None,
auth_kwargs: Optional[dict[str,str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
@@ -362,7 +370,7 @@ def upload_video_to_comfyapi(
Args:
video: VideoInput object (Comfy VIDEO type).
auth_token: Optional authentication token.
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
@@ -390,7 +398,7 @@ def upload_video_to_comfyapi(
video_bytes_io.seek(0)
return upload_file_to_comfyapi(
video_bytes_io, filename, upload_mime_type, auth_token
video_bytes_io, filename, upload_mime_type, auth_kwargs
)
@@ -453,7 +461,7 @@ def audio_ndarray_to_bytesio(
def upload_audio_to_comfyapi(
audio: AudioInput,
auth_token: Optional[str] = None,
auth_kwargs: Optional[dict[str,str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
@@ -465,7 +473,7 @@ def upload_audio_to_comfyapi(
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_token: Optional authentication token.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
@@ -477,11 +485,11 @@ def upload_audio_to_comfyapi(
audio_data_np, sample_rate, container_format, codec_name
)
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_token)
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def upload_images_to_comfyapi(
image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@@ -490,7 +498,7 @@ def upload_images_to_comfyapi(
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_token: Optional authentication token.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
@@ -521,7 +529,7 @@ def upload_images_to_comfyapi(
response_model=UploadResponse,
),
request=request_object,
auth_token=auth_token,
auth_kwargs=auth_kwargs,
)
response = operation.execute()

View File

@@ -20,7 +20,8 @@ Usage Examples:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
api_key="your_api_key_here",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
@@ -93,15 +94,19 @@ from __future__ import annotations
import logging
import time
import io
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
import socket
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
import requests
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
@@ -110,6 +115,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
@@ -140,20 +160,36 @@ class HttpMethod(str, Enum):
class ApiClient:
"""
Client for making HTTP requests to an API with authentication and error handling.
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
api_key: Optional[str] = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
):
self.base_url = base_url
self.api_key = api_key
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
def _generate_operation_id(self, path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
def _create_json_payload_args(
self,
@@ -201,11 +237,63 @@ class ApiClient:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False
}
# First check basic internet connectivity using a reliable external site
try:
# Use a reliable external domain for checking basic connectivity
check_response = requests.get("https://www.google.com",
timeout=5.0,
verify=self.verify_ssl)
if check_response.status_code < 500:
results["internet_accessible"] = True
except (requests.RequestException, socket.error):
results["internet_accessible"] = False
results["is_local_issue"] = True
return results
# Now check API server connectivity
try:
# Extract domain from the target URL to do a simpler health check
parsed_url = urlparse(target_url)
api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"
# Try to reach the API domain
api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
if api_response.status_code < 500:
results["api_accessible"] = True
else:
results["api_accessible"] = False
results["is_api_issue"] = True
except requests.RequestException:
results["api_accessible"] = False
# If we can reach the internet but not the API, it's an API issue
results["is_api_issue"] = True
return results
def request(
self,
method: str,
@@ -216,9 +304,10 @@ class ApiClient:
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
"""
Make an HTTP request to the API
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
@@ -228,15 +317,18 @@ class ApiClient:
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
requests.RequestException: If the request fails
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
url = urljoin(self.base_url, path)
self.check_auth_token(self.api_key)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()
if headers:
@@ -260,6 +352,16 @@ class ApiClient:
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]"
)
try:
response = requests.request(
method=method,
@@ -270,87 +372,365 @@ class ApiClient:
**payload_args,
)
# Check if we should retry based on status code
if (response.status_code in self.retry_status_codes and
retry_count < self.max_retries):
# Calculate delay with exponential backoff
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request failed with status {response.status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# Raise exception for error status codes
response.raise_for_status()
except requests.ConnectionError:
raise Exception(
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
# Log successful response
response_content_to_log = response.content
try:
# Attempt to parse JSON for prettier logging, fallback to raw content
response_content_to_log = response.json()
except json.JSONDecodeError:
pass # Keep as bytes/str if not JSON
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, # Pass request details again for context in log
request_url=url,
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content=response_content_to_log
)
except requests.Timeout:
raise Exception(
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
except requests.ConnectionError as e:
error_message = f"ConnectionError: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
error_message=error_message
)
# Only perform connectivity check if we've exhausted all retries
if retry_count >= self.max_retries:
# Check connectivity to determine if it's a local or API issue
connectivity = self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
elif connectivity["is_api_issue"]:
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
# If we haven't exhausted retries yet, retry the request
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Connection error: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# If we've exhausted retries and didn't identify the specific issue,
# raise a generic exception
final_error_message = (
f"Unable to connect to the API server after {self.max_retries} attempts. "
f"Please check your internet connection or try again later."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.Timeout as e:
error_message = f"Timeout: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, request_url=url,
error_message=error_message
)
# Retry timeouts if we haven't exhausted retries
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request timed out. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
final_error_message = (
f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
f"The server might be experiencing high load or the operation is taking longer than expected."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.HTTPError as e:
status_code = e.response.status_code if hasattr(e, "response") else None
error_message = f"HTTP Error: {str(e)}"
original_error_message = f"HTTP Error: {str(e)}"
error_content_for_log = None
if hasattr(e, "response") and e.response is not None:
error_content_for_log = e.response.content
try:
error_content_for_log = e.response.json()
except json.JSONDecodeError:
pass
# Try to extract detailed error message from JSON response for user display
# but log the full error content.
user_display_error_message = original_error_message
# Try to extract detailed error message from JSON response
try:
if hasattr(e, "response") and e.response.content:
if hasattr(e, "response") and e.response is not None and e.response.content:
error_json = e.response.json()
if "error" in error_json and "message" in error_json["error"]:
error_message = f"API Error: {error_json['error']['message']}"
user_display_error_message = f"API Error: {error_json['error']['message']}"
if "type" in error_json["error"]:
error_message += f" (Type: {error_json['error']['type']})"
user_display_error_message += f" (Type: {error_json['error']['type']})"
elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
user_display_error_message = f"API Error: {json.dumps(error_json)}"
else: # Non-dict JSON error
user_display_error_message = f"API Error: {str(error_json)}"
except json.JSONDecodeError:
# If not JSON, use the raw content if it's not too long, or a summary
if hasattr(e, "response") and e.response is not None and e.response.content:
raw_content = e.response.content.decode(errors='ignore')
if len(raw_content) < 200: # Arbitrary limit for display
user_display_error_message = f"API Error (raw): {raw_content}"
else:
error_message = f"API Error: {error_json}"
except Exception as json_error:
# If we can't parse the JSON, fall back to the original error message
logging.debug(
f"[DEBUG] Failed to parse error response: {str(json_error)}"
user_display_error_message = f"API Error (raw, status {status_code})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, request_url=url,
response_status_code=status_code,
response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
response_content=error_content_for_log,
error_message=original_error_message # Log the original exception string as error
)
logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response is not None and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
# Retry if the status code is in our retry list and we haven't exhausted retries
if (status_code in self.retry_status_codes and
retry_count < self.max_retries):
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"HTTP error {status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
# Specific error messages for common status codes for user display
if status_code == 401:
error_message = "Unauthorized: Please login first to use this node."
if status_code == 402:
error_message = "Payment Required: Please add credits to your account to use this node."
if status_code == 409:
error_message = "There is a problem with your account. Please contact support@comfy.org. "
if status_code == 429:
error_message = "Rate Limit Exceeded: Please try again later."
raise Exception(error_message)
user_display_error_message = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_display_error_message = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_display_error_message = "Rate Limit Exceeded: Please try again later."
# else, user_display_error_message remains as parsed from response or original HTTPError string
raise Exception(user_display_error_message) # Raise with the user-friendly message
# Parse and return JSON response
if response.content:
return response.json()
return {}
def check_auth_token(self, auth_token):
"""Verify that an auth token is present."""
if auth_token is None:
def check_auth(self, auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token
return auth_token or comfy_api_key
@staticmethod
def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
):
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
mime_type: Optional mime type to set for the upload
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers = {}
if content_type:
headers["Content-Type"] = content_type
# Prepare the file data
if isinstance(file, io.BytesIO):
file.seek(0) # Ensure we're at the start of the file
data = file.read()
return requests.put(upload_url, data=data, headers=headers)
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
return requests.put(upload_url, data=data, headers=headers)
else:
raise ValueError("File must be either a BytesIO object or a file path string")
# Try the upload with retries
last_exception = None
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
# Log initial attempt (without full file data for brevity)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
)
for retry_attempt in range(max_retries + 1):
try:
response = requests.put(upload_url, data=data, headers=headers)
response.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url, # For context
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content="File uploaded successfully." # Or response.text if available
)
return response
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
last_exception = e
error_message_for_log = f"{type(e).__name__}: {str(e)}"
response_content_for_log = None
status_code_for_log = None
headers_for_log = None
if hasattr(e, 'response') and e.response is not None:
status_code_for_log = e.response.status_code
headers_for_log = dict(e.response.headers)
try:
response_content_for_log = e.response.json()
except json.JSONDecodeError:
response_content_for_log = e.response.content
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
response_status_code=status_code_for_log,
response_headers=headers_for_log,
response_content=response_content_for_log,
error_message=error_message_for_log
)
if retry_attempt < max_retries:
delay = retry_delay * (retry_backoff_factor ** retry_attempt)
logging.warning(
f"File upload failed: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
)
time.sleep(delay)
else:
break # Max retries reached
# If we've exhausted all retries, determine the final error type and raise
final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
try:
# Check basic internet connectivity
check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
if check_response.status_code >= 500: # Google itself has an issue (rare)
final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
f"(status {check_response.status_code}). Original error: {str(last_exception)}")
# Not raising LocalNetworkError here as Google itself might be down.
# If Google is reachable, the issue is likely with the upload server or a more specific local problem
# not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
# The original last_exception is probably most relevant.
except (requests.RequestException, socket.error) as conn_check_exc:
# Could not reach Google, likely a local network issue
final_error_message = (f"Failed to upload file due to network connectivity issues "
f"(cannot reach Google: {str(conn_check_exc)}). "
f"Original upload error: {str(last_exception)}")
request_logger.log_request_response( # Log final failure reason
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise LocalNetworkError(final_error_message) from last_exception
request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise Exception(final_error_message) from last_exception
class ApiEndpoint(Generic[T, R]):
@@ -392,10 +772,15 @@ class SynchronousOperation(Generic[T, R]):
files: Optional[Dict[str, Any]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
timeout: float = 604800.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
):
self.endpoint = endpoint
self.request = request
@@ -403,21 +788,33 @@ class SynchronousOperation(Generic[T, R]):
self.error = None
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.files = files
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the API operation using the provided client or create one"""
"""Execute the API operation using the provided client or create one with retry support"""
try:
# Create client if not provided
if client is None:
client = ApiClient(
base_url=self.api_base,
api_key=self.auth_token,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
# Convert request model to dict, but use None for EmptyRequest
@@ -431,11 +828,6 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(value, Enum):
request_dict[key] = value.value
if request_dict:
for key, value in request_dict.items():
if isinstance(value, Enum):
request_dict[key] = value.value
# Debug log for request
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
@@ -443,7 +835,7 @@ class SynchronousOperation(Generic[T, R]):
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
# Make the request
# Make the request with built-in retry
resp = client.request(
method=self.endpoint.method.value,
path=self.endpoint.path,
@@ -464,8 +856,18 @@ class SynchronousOperation(Generic[T, R]):
# Parse and return the response
return self._parse_response(resp)
except LocalNetworkError as e:
# Propagate specific network error types
logging.error(f"[ERROR] Local network error: {str(e)}")
raise
except ApiServerError as e:
# Propagate API server errors
logging.error(f"[ERROR] API server error: {str(e)}")
raise
except Exception as e:
logging.error(f"[DEBUG] API Exception: {str(e)}")
logging.error(f"[ERROR] API Exception: {str(e)}")
raise Exception(str(e))
def _parse_response(self, resp):
@@ -499,22 +901,42 @@ class PollingOperation(Generic[T, R]):
failed_statuses: list,
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] = None,
result_url_extractor: Callable[[R], str] = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
):
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
# Polling configuration
self.status_extractor = status_extractor or (
lambda x: getattr(x, "status", None)
)
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
@@ -528,12 +950,48 @@ class PollingOperation(Generic[T, R]):
if client is None:
client = ApiClient(
base_url=self.api_base,
api_key=self.auth_token,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
return self._poll_until_complete(client)
except LocalNetworkError as e:
# Provide clear message for local network issues
raise Exception(
f"Polling failed due to local network issues. Please check your internet connection. "
f"Details: {str(e)}"
) from e
except ApiServerError as e:
# Provide clear message for API server issues
raise Exception(
f"Polling failed due to API server issues. The service may be experiencing problems. "
f"Please try again later. Details: {str(e)}"
) from e
except Exception as e:
raise Exception(f"Error during polling: {str(e)}")
def _display_text_on_node(self, text: str):
"""Sends text to the client which will be displayed on the node in the UI"""
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int):
if not self.node_id:
return
if self.estimated_duration is not None:
estimated_time_remaining = max(
0, int(self.estimated_duration) - int(time_completed)
)
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
else:
message = f"Task in progress: {time_completed:.0f}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
"""Check task status using the status extractor function"""
try:
@@ -550,10 +1008,13 @@ class PollingOperation(Generic[T, R]):
def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
poll_count = 0
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
while True:
while poll_count < self.max_poll_attempts:
try:
poll_count += 1
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
@@ -580,8 +1041,12 @@ class PollingOperation(Generic[T, R]):
data=request_dict,
)
# Successfully got a response, reset consecutive error count
consecutive_errors = 0
# Parse response
response_obj = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
@@ -593,7 +1058,15 @@ class PollingOperation(Generic[T, R]):
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
logging.debug("[DEBUG] Task completed successfully")
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
else:
message = "Task completed successfully!"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
@@ -609,8 +1082,43 @@ class PollingOperation(Generic[T, R]):
logging.debug(
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
)
for i in range(int(self.poll_interval)):
time_completed = (poll_count * self.poll_interval) + i
self._display_time_progress_on_node(time_completed)
time.sleep(1)
except (LocalNetworkError, ApiServerError) as e:
# For network-related errors, increment error count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
) from e
# Log the error but continue polling
logging.warning(
f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
raise Exception(f"Error while polling: {str(e)}")
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
f"The operation may still be running on the server but is taking longer than expected."
)

View File

@@ -81,7 +81,6 @@ class RecraftStyle:
class RecraftIO:
STYLEV3 = "RECRAFT_V3_STYLE"
SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
COLOR = "RECRAFT_COLOR"
CONTROLS = "RECRAFT_CONTROLS"

View File

@@ -0,0 +1,125 @@
import os
import datetime
import json
import logging
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__':
# Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG)
# Mock folder_paths for direct execution if not running within ComfyUI full context
if not hasattr(folder_paths, 'get_temp_directory'):
class MockFolderPaths:
def get_temp_directory(self):
# Create a local temp dir for testing if needed
p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
os.makedirs(p, exist_ok=True)
return p
folder_paths = MockFolderPaths()
log_request_response(
operation_id="test_operation_get",
request_method="GET",
request_url="https://api.example.com/test",
request_headers={"Authorization": "Bearer testtoken"},
request_params={"param1": "value1"},
response_status_code=200,
response_content={"message": "Success!"}
)
log_request_response(
operation_id="test_operation_post_error",
request_method="POST",
request_url="https://api.example.com/submit",
request_data={"key": "value", "nested": {"num": 123}},
error_message="Connection timed out"
)
log_request_response(
operation_id="test_binary_response",
request_method="GET",
request_url="https://api.example.com/image.png",
response_status_code=200,
response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
)

10
comfy_api_nodes/canary.py Normal file
View File

@@ -0,0 +1,10 @@
import av
ver = av.__version__.split(".")
if int(ver[0]) < 14:
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
if int(ver[0]) == 14 and int(ver[1]) < 2:
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
NODE_CLASS_MAPPINGS = {}

View File

@@ -1,5 +1,6 @@
import io
from inspect import cleandoc
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
@@ -30,6 +31,7 @@ import requests
import torch
import base64
import time
from server import PromptServer
def convert_mask_to_image(mask: torch.Tensor):
@@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor):
def handle_bfl_synchronous_operation(
operation: SynchronousOperation, timeout_bfl_calls=360
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)
def _poll_until_generated(polling_url: str, timeout=360):
def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
@@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
response = requests.Session().send(request.prepare())
if response.status_code == 200:
result = response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
img_response = requests.get(img_url)
return process_image_response(img_response)
elif result["status"] in [
@@ -179,6 +196,8 @@ class FluxProUltraImageNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -211,7 +230,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
if image_prompt is None:
@@ -244,9 +263,9 @@ class FluxProUltraImageNode(ComfyNodeABC):
None if image_prompt is None else round(image_prompt_strength, 2)
),
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -319,6 +338,8 @@ class FluxProImageNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -337,7 +358,7 @@ class FluxProImageNode(ComfyNodeABC):
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
image_prompt = (
@@ -361,9 +382,9 @@ class FluxProImageNode(ComfyNodeABC):
seed=seed,
image_prompt=image_prompt,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -457,10 +478,11 @@ class FluxProExpandNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -482,7 +504,7 @@ class FluxProExpandNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
image = convert_image_to_base64(image)
@@ -506,9 +528,9 @@ class FluxProExpandNode(ComfyNodeABC):
seed=seed,
image=image,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -568,10 +590,11 @@ class FluxProFillNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -590,14 +613,14 @@ class FluxProFillNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
# prepare mask
mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed
image = convert_image_to_base64(image[:,:,:,:3])
image = convert_image_to_base64(image[:, :, :, :3])
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -615,9 +638,9 @@ class FluxProFillNode(ComfyNodeABC):
image=image,
mask=mask,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -702,10 +725,11 @@ class FluxProCannyNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -726,10 +750,10 @@ class FluxProCannyNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
control_image = convert_image_to_base64(control_image[:, :, :, :3])
preprocessed_image = None
# scale canny threshold between 0-500, to match BFL's API
@@ -763,9 +787,9 @@ class FluxProCannyNode(ComfyNodeABC):
canny_high_threshold=canny_high_threshold,
preprocessed_image=preprocessed_image,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -830,10 +854,11 @@ class FluxProDepthNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -852,7 +877,7 @@ class FluxProDepthNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
auth_token=None,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
@@ -878,9 +903,9 @@ class FluxProDepthNode(ComfyNodeABC):
control_image=control_image,
preprocessed_image=preprocessed_image,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)

View File

@@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
resize_mask_to_image,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@@ -232,11 +233,22 @@ def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(ComfyNodeABC):
"""
Generates images synchronously using the Ideogram V1 model.
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
Generates images using the Ideogram V1 model.
"""
def __init__(self):
@@ -303,7 +315,11 @@ class IdeogramV1(ComfyNodeABC):
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -321,7 +337,8 @@ class IdeogramV1(ComfyNodeABC):
seed=0,
negative_prompt="",
num_images=1,
auth_token=None,
unique_id=None,
**kwargs,
):
# Determine the model based on turbo setting
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
@@ -347,7 +364,7 @@ class IdeogramV1(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = operation.execute()
@@ -360,14 +377,13 @@ class IdeogramV1(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
class IdeogramV2(ComfyNodeABC):
"""
Generates images synchronously using the Ideogram V2 model.
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
Generates images using the Ideogram V2 model.
"""
def __init__(self):
@@ -458,7 +474,11 @@ class IdeogramV2(ComfyNodeABC):
# },
#),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -479,7 +499,8 @@ class IdeogramV2(ComfyNodeABC):
negative_prompt="",
num_images=1,
color_palette="",
auth_token=None,
unique_id=None,
**kwargs,
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None)
@@ -519,7 +540,7 @@ class IdeogramV2(ComfyNodeABC):
color_palette=color_palette if color_palette else None,
)
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = operation.execute()
@@ -532,14 +553,12 @@ class IdeogramV2(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
class IdeogramV3(ComfyNodeABC):
"""
Generates images synchronously using the Ideogram V3 model.
Supports both regular image generation from text prompts and image editing with mask.
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
"""
def __init__(self):
@@ -621,7 +640,11 @@ class IdeogramV3(ComfyNodeABC):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -641,7 +664,8 @@ class IdeogramV3(ComfyNodeABC):
seed=0,
num_images=1,
rendering_speed="BALANCED",
auth_token=None,
unique_id=None,
**kwargs,
):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
@@ -705,7 +729,7 @@ class IdeogramV3(ComfyNodeABC):
"mask": mask_binary,
},
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
elif image is not None or mask is not None:
@@ -746,7 +770,7 @@ class IdeogramV3(ComfyNodeABC):
response_model=IdeogramGenerateResponse,
),
request=gen_request,
auth_token=auth_token,
auth_kwargs=kwargs,
)
# Execute the operation and process response
@@ -760,6 +784,7 @@ class IdeogramV3(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
@@ -774,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"IdeogramV2": "Ideogram V2",
"IdeogramV3": "Ideogram V3",
}

View File

@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
from __future__ import annotations
from typing import Optional, TypeVar, Any
from collections.abc import Callable
import math
import logging
@@ -64,6 +65,12 @@ from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api_nodes.util.validation_utils import (
validate_image_dimensions,
validate_image_aspect_ratio,
validate_video_dimensions,
validate_video_duration,
)
from comfy_api.input.basic_types import AudioInput
from comfy_api.input.video_types import VideoInput
from comfy_api.input_impl import VideoFromFile
@@ -79,13 +86,20 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on"
PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
MAX_PROMPT_LENGTH_T2V = 2500
MAX_PROMPT_LENGTH_I2V = 500
MAX_PROMPT_LENGTH_IMAGE_GEN = 500
MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
MAX_PROMPT_LENGTH_LIP_SYNC = 120
AVERAGE_DURATION_T2V = 319
AVERAGE_DURATION_I2V = 164
AVERAGE_DURATION_LIP_SYNC = 455
AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
AVERAGE_DURATION_IMAGE_GEN = 32
AVERAGE_DURATION_VIDEO_EFFECTS = 320
AVERAGE_DURATION_VIDEO_EXTEND = 320
R = TypeVar("R")
@@ -95,7 +109,13 @@ class KlingApiError(Exception):
pass
def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R:
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
@@ -108,7 +128,10 @@ def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R
if response.data and response.data.task_status
else None
),
auth_token=auth_token,
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
).execute()
@@ -184,6 +207,18 @@ def validate_image_result_response(response) -> None:
raise KlingApiError(error_msg)
def validate_input_image(image: torch.Tensor) -> None:
"""
Validates the input image adheres to the expectations of the Kling API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
"""
validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
def get_camera_control_input_config(
tooltip: str, default: float = 0.0
) -> tuple[IO, InputTypeOptions]:
@@ -200,7 +235,9 @@ def get_camera_control_input_config(
def get_video_from_response(response) -> KlingVideoResult:
"""Returns the first video object from the Kling video generation task result."""
"""Returns the first video object from the Kling video generation task result.
Will raise an error if the response is not valid.
"""
video = response.data.task_result.videos[0]
logging.info(
"Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
@@ -208,12 +245,37 @@ def get_video_from_response(response) -> KlingVideoResult:
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Kling video generation task result.
Will not raise an error if the response is not valid.
"""
if response and is_valid_video_response(response):
return str(get_video_from_response(response).url)
else:
return None
def get_images_from_response(response) -> list[KlingImageResult]:
"""Returns the list of image objects from the Kling image generation task result.
Will raise an error if the response is not valid.
"""
images = response.data.task_result.images
logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
return images
def get_images_urls_from_response(response) -> Optional[str]:
"""Returns the list of image urls from the Kling image generation task result.
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
"""
if response and is_valid_image_response(response):
images = get_images_from_response(response)
image_urls = [str(image.url) for image in images]
return "\n".join(image_urls)
else:
return None
def video_result_to_node_output(
video: KlingVideoResult,
) -> tuple[VideoFromFile, str, str]:
@@ -285,6 +347,7 @@ class KlingCameraControls(KlingNodeBase):
RETURN_TYPES = ("CAMERA_CONTROL",)
RETURN_NAMES = ("camera_control",)
FUNCTION = "main"
API_NODE = False # This is just a helper node, it doesn't make an API call
@classmethod
def VALIDATE_INPUTS(
@@ -391,22 +454,31 @@ class KlingTextToVideoNode(KlingNodeBase):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Text to Video Node"
def get_response(self, task_id: str, auth_token: str) -> KlingText2VideoResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingText2VideoResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingText2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
node_id=node_id,
)
def api_call(
@@ -419,7 +491,8 @@ class KlingTextToVideoNode(KlingNodeBase):
camera_control: Optional[KlingCameraControl] = None,
model_name: Optional[str] = None,
duration: Optional[str] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
if model_name is None:
@@ -441,14 +514,16 @@ class KlingTextToVideoNode(KlingNodeBase):
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
camera_control=camera_control,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -495,7 +570,11 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
@@ -507,7 +586,8 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
cfg_scale: float,
aspect_ratio: str,
camera_control: Optional[KlingCameraControl] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
model_name=KlingVideoGenModelName.kling_v1,
@@ -518,7 +598,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
prompt=prompt,
negative_prompt=negative_prompt,
camera_control=camera_control,
auth_token=auth_token,
**kwargs,
)
@@ -530,7 +610,10 @@ class KlingImage2VideoNode(KlingNodeBase):
return {
"required": {
"start_frame": model_field_to_node_input(
IO.IMAGE, KlingImage2VideoRequest, "image"
IO.IMAGE,
KlingImage2VideoRequest,
"image",
tooltip="The reference image used to generate the video.",
),
"prompt": model_field_to_node_input(
IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
@@ -574,22 +657,31 @@ class KlingImage2VideoNode(KlingNodeBase):
enum_type=KlingVideoGenDuration,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Image to Video Node"
def get_response(self, task_id: str, auth_token: str) -> KlingImage2VideoResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingImage2VideoResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
method=HttpMethod.GET,
request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
node_id=node_id,
)
def api_call(
@@ -604,12 +696,14 @@ class KlingImage2VideoNode(KlingNodeBase):
duration: str,
camera_control: Optional[KlingCameraControl] = None,
end_frame: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
validate_input_image(start_frame)
if camera_control is not None:
# Camera control type for image 2 video is always simple
# Camera control type for image 2 video is always `simple`
camera_control.type = KlingCameraControlType.simple
initial_operation = SynchronousOperation(
@@ -631,18 +725,19 @@ class KlingImage2VideoNode(KlingNodeBase):
negative_prompt=negative_prompt if negative_prompt else None,
cfg_scale=cfg_scale,
mode=KlingVideoGenMode(mode),
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
duration=KlingVideoGenDuration(duration),
camera_control=camera_control,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -692,7 +787,11 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
@@ -705,7 +804,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
cfg_scale: float,
aspect_ratio: str,
camera_control: KlingCameraControl,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
model_name=KlingVideoGenModelName.kling_v1_5,
@@ -717,7 +817,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
prompt=prompt,
negative_prompt=negative_prompt,
camera_control=camera_control,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
@@ -785,7 +886,11 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
@@ -799,7 +904,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
cfg_scale: float,
aspect_ratio: str,
mode: str,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
mode
@@ -814,7 +920,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
aspect_ratio=aspect_ratio,
duration=duration,
end_frame=end_frame,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
@@ -844,22 +951,31 @@ class KlingVideoExtendNode(KlingNodeBase):
IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
def get_response(self, task_id: str, auth_token: str) -> KlingVideoExtendResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoExtendResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVideoExtendResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
node_id=node_id,
)
def api_call(
@@ -868,7 +984,8 @@ class KlingVideoExtendNode(KlingNodeBase):
negative_prompt: str,
cfg_scale: float,
video_id: str,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
initial_operation = SynchronousOperation(
@@ -884,14 +1001,16 @@ class KlingVideoExtendNode(KlingNodeBase):
cfg_scale=cfg_scale,
video_id=video_id,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -904,15 +1023,20 @@ class KlingVideoEffectsBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
def get_response(self, task_id: str, auth_token: str) -> KlingVideoEffectsResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoEffectsResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVideoEffectsResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
node_id=node_id,
)
def api_call(
@@ -924,7 +1048,8 @@ class KlingVideoEffectsBase(KlingNodeBase):
image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None,
mode: Optional[KlingVideoGenMode] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
@@ -954,14 +1079,16 @@ class KlingVideoEffectsBase(KlingNodeBase):
effect_scene=effect_scene,
input=request_input_field,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -1002,7 +1129,11 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite."
@@ -1017,7 +1148,8 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
model_name: KlingCharacterEffectModelName,
mode: KlingVideoGenMode,
duration: KlingVideoGenDuration,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
video, _, duration = super().api_call(
dual_character=True,
@@ -1027,10 +1159,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
duration=duration,
image_1=image_left,
image_2=image_right,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
return video, duration
class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
"""Kling Single Image Video Effect Node"""
@@ -1063,7 +1197,11 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
@@ -1074,7 +1212,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
effect_scene: KlingSingleImageEffectsScene,
model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
dual_character=False,
@@ -1082,7 +1221,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
model_name=model_name,
duration=duration,
image_1=image,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
@@ -1092,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
def validate_lip_sync_video(self, video: VideoInput):
"""
Validates the input video adheres to the expectations of the Kling Lip Sync API:
- Video length does not exceed 10s and is not shorter than 2s
- Length and width dimensions should both be between 720px and 1920px
See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
"""
validate_video_dimensions(video, 720, 1920)
validate_video_duration(video, 2, 10)
def validate_text(self, text: str):
if not text:
raise ValueError("Text is required")
@@ -1100,16 +1251,21 @@ class KlingLipSyncBase(KlingNodeBase):
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
)
def get_response(self, task_id: str, auth_token: str) -> KlingLipSyncResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingLipSyncResponse:
"""Polls the Kling API endpoint until the task reaches a terminal state."""
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_LIP_SYNC}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingLipSyncResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id,
)
def api_call(
@@ -1121,18 +1277,20 @@ class KlingLipSyncBase(KlingNodeBase):
text: Optional[str] = None,
voice_speed: Optional[float] = None,
voice_id: Optional[str] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
if text:
self.validate_text(text)
self.validate_lip_sync_video(video)
# Upload video to Comfy API and get download URL
video_url = upload_video_to_comfyapi(video, auth_token)
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = upload_audio_to_comfyapi(audio, auth_token)
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
audio_url = None
@@ -1156,14 +1314,16 @@ class KlingLipSyncBase(KlingNodeBase):
voice_id=voice_id,
),
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -1186,24 +1346,30 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
enum_type=KlingLipSyncVoiceLanguage,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
self,
video: VideoInput,
audio: AudioInput,
voice_language: str,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
video=video,
audio=audio,
voice_language=voice_language,
mode="audio2video",
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
@@ -1292,10 +1458,14 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
self,
@@ -1303,7 +1473,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
text: str,
voice: str,
voice_speed: float,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
return super().api_call(
@@ -1313,7 +1484,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
voice_id=voice_id,
voice_speed=voice_speed,
mode="text2video",
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
@@ -1350,22 +1522,29 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
enum_type=KlingVirtualTryOnModelName,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human."
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
def get_response(
self, task_id: str, auth_token: Optional[str] = None
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVirtualTryOnResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=node_id,
)
def api_call(
@@ -1373,7 +1552,8 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
human_image: torch.Tensor,
cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -1387,14 +1567,16 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
cloth_image=tensor_to_base64_string(cloth_image),
model_name=model_name,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)
@@ -1462,22 +1644,32 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
"optional": {
"image": (IO.IMAGE, {}),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
def get_response(
self, task_id: str, auth_token: Optional[str] = None
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]],
node_id: Optional[str] = None,
) -> KlingImageGenerationsResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=node_id,
)
def api_call(
@@ -1491,7 +1683,8 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
self.validate_prompt(prompt, negative_prompt)
@@ -1516,14 +1709,16 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n=n,
aspect_ratio=aspect_ratio,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)

View File

@@ -1,4 +1,6 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
@@ -34,11 +36,20 @@ from comfy_api_nodes.apinode_utils import (
process_image_response,
validate_string,
)
from server import PromptServer
import requests
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC):
"""
@@ -201,6 +212,8 @@ class LumaImageGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -214,7 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3)
@@ -222,19 +235,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_token=auth_token
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = self._convert_style_image(
style_image, weight=style_image_weight, auth_token=auth_token
style_image, weight=style_image_weight, auth_kwargs=kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = upload_images_to_comfyapi(
character_image, max_images=4, auth_token=auth_token
character_image, max_images=4, auth_kwargs=kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@@ -255,7 +268,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
@@ -269,7 +282,9 @@ class LumaImageGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -278,13 +293,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
return (img,)
def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_token=None
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = upload_images_to_comfyapi(
ref.image, max_images=1, auth_token=auth_token
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
ref_count += 1
@@ -293,12 +308,12 @@ class LumaImageGenerationNode(ComfyNodeABC):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_token=None
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return self._convert_luma_refs(chain, max_refs=1, auth_token=auth_token)
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
@@ -350,6 +365,8 @@ class LumaImageModifyNode(ComfyNodeABC):
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -360,12 +377,12 @@ class LumaImageModifyNode(ComfyNodeABC):
image: torch.Tensor,
image_weight: float,
seed,
auth_token=None,
unique_id: str = None,
**kwargs,
):
# first, upload image
download_urls = upload_images_to_comfyapi(
image, max_images=1, auth_token=auth_token
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@@ -383,7 +400,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
@@ -397,7 +414,9 @@ class LumaImageModifyNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -470,6 +489,8 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -483,7 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3)
@@ -506,10 +527,13 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -520,7 +544,10 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -594,6 +621,8 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -608,14 +637,14 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_token)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@@ -636,10 +665,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -650,7 +682,10 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -661,7 +696,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_token=None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
@@ -669,12 +704,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame1 = None
if first_image is not None:
download_urls = upload_images_to_comfyapi(
first_image, max_images=1, auth_token=auth_token
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = upload_images_to_comfyapi(
last_image, max_images=1, auth_token=auth_token
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)

View File

@@ -1,3 +1,7 @@
from typing import Union
import logging
import torch
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
@@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer
import torch
import logging
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
class MinimaxTextToVideoNode:
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -67,6 +74,8 @@ class MinimaxTextToVideoNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -84,7 +93,8 @@ class MinimaxTextToVideoNode:
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
auth_token=None,
unique_id: Union[str, None]=None,
**kwargs,
):
'''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
@@ -94,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_token=auth_token)[0]
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_token=auth_token)[0]
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
@@ -118,7 +128,7 @@ class MinimaxTextToVideoNode:
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()
@@ -137,7 +147,9 @@ class MinimaxTextToVideoNode:
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
auth_token=auth_token,
estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
@@ -153,7 +165,7 @@ class MinimaxTextToVideoNode:
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_token=auth_token,
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()
@@ -163,6 +175,12 @@ class MinimaxTextToVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = download_url_to_bytesio(file_url)
if video_io is None:
@@ -177,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = I2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -221,6 +241,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -237,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -279,6 +303,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

View File

@@ -93,7 +93,11 @@ class OpenAIDalle2(ComfyNodeABC):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -110,7 +114,8 @@ class OpenAIDalle2(ComfyNodeABC):
mask=None,
n=1,
size="1024x1024",
auth_token=None,
unique_id=None,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-2"
@@ -168,12 +173,12 @@ class OpenAIDalle2(ComfyNodeABC):
else None
),
content_type=content_type,
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@@ -236,7 +241,11 @@ class OpenAIDalle3(ComfyNodeABC):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -252,7 +261,8 @@ class OpenAIDalle3(ComfyNodeABC):
style="natural",
quality="standard",
size="1024x1024",
auth_token=None,
unique_id=None,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-3"
@@ -273,12 +283,12 @@ class OpenAIDalle3(ComfyNodeABC):
style=style,
seed=seed,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@@ -366,7 +376,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
},
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE,)
@@ -385,7 +399,8 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask=None,
n=1,
size="1024x1024",
auth_token=None,
unique_id=None,
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
@@ -462,12 +477,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
),
files=files if files else None,
content_type=content_type,
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)

View File

@@ -3,6 +3,7 @@ Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
import io
from typing import Optional, TypeVar
@@ -120,7 +121,10 @@ class PikaNodeBase(ComfyNodeABC):
RETURN_TYPES = ("VIDEO",)
def poll_for_task_status(
self, task_id: str, auth_token: str
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> PikaGenerateResponse:
polling_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -139,20 +143,26 @@ class PikaNodeBase(ComfyNodeABC):
progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None
),
auth_token=auth_token,
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (
response.url if hasattr(response, "url") else None
),
node_id=node_id,
estimated_duration=60
)
return polling_operation.execute()
def execute_task(
self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_token: Optional[str] = None,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
Args:
initial_operation: The initial operation to execute.
auth_token: The authentication token to use for the API call.
auth_kwargs: The authentication token(s) to use for the API call.
Returns:
A tuple containing the video file as a VIDEO output.
@@ -164,7 +174,7 @@ class PikaNodeBase(ComfyNodeABC):
raise PikaApiError(error_msg)
task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_token)
final_response = self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
@@ -193,6 +203,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@@ -206,7 +217,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
seed: int,
resolution: str,
duration: int,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
@@ -233,10 +245,10 @@ class PikaImageToVideoV2_2(PikaNodeBase):
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaTextToVideoNodeV2_2(PikaNodeBase):
@@ -259,6 +271,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -272,7 +286,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
resolution: str,
duration: int,
aspect_ratio: float,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -289,11 +304,11 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
duration=duration,
aspectRatio=aspect_ratio,
),
auth_token=auth_token,
auth_kwargs=kwargs,
content_type="application/x-www-form-urlencoded",
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaScenesV2_2(PikaNodeBase):
@@ -336,6 +351,8 @@ class PikaScenesV2_2(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -350,12 +367,13 @@ class PikaScenesV2_2(PikaNodeBase):
duration: int,
ingredients_mode: str,
aspect_ratio: float,
unique_id: str,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert all passed images to BytesIO
all_image_bytes_io = []
@@ -396,10 +414,10 @@ class PikaScenesV2_2(PikaNodeBase):
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikAdditionsNode(PikaNodeBase):
@@ -434,6 +452,8 @@ class PikAdditionsNode(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -446,7 +466,8 @@ class PikAdditionsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
video_bytes_io = io.BytesIO()
@@ -479,10 +500,10 @@ class PikAdditionsNode(PikaNodeBase):
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaSwapsNode(PikaNodeBase):
@@ -526,6 +547,8 @@ class PikaSwapsNode(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -540,7 +563,8 @@ class PikaSwapsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
video_bytes_io = io.BytesIO()
@@ -583,10 +607,10 @@ class PikaSwapsNode(PikaNodeBase):
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaffectsNode(PikaNodeBase):
@@ -630,6 +654,8 @@ class PikaffectsNode(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -642,7 +668,8 @@ class PikaffectsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
@@ -660,10 +687,10 @@ class PikaffectsNode(PikaNodeBase):
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaStartEndFrameNode2_2(PikaNodeBase):
@@ -681,6 +708,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -695,7 +724,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
seed: int,
resolution: str,
duration: int,
auth_token: Optional[str] = None,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
pika_files = [
@@ -722,10 +752,10 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
),
files=pika_files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_token)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
NODE_CLASS_MAPPINGS = {

View File

@@ -1,5 +1,5 @@
from inspect import cleandoc
from typing import Optional
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -34,11 +34,22 @@ import requests
from io import BytesIO
def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {
"image": tensor_to_bytesio(image)
}
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
@@ -49,12 +60,14 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = operation.execute()
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)
return response_upload.Resp.img_id
@@ -73,7 +86,7 @@ class PixverseTemplateNode:
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()), ),
"template": (list(pixverse_templates.keys()),),
}
}
@@ -87,7 +100,7 @@ class PixverseTemplateNode:
class PixverseTextToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": (
[ratio.value for ratio in PixverseAspectRatio],
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
@@ -143,11 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -159,9 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
auth_token=None,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
@@ -190,7 +203,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_token=auth_token,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseImageToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
),
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
@@ -273,11 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -289,13 +310,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
auth_token=None,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_token=auth_token)
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -322,7 +343,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_token=auth_token,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = operation.execute()
@@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (
IO.IMAGE,
),
"last_frame": (
IO.IMAGE,
),
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
@@ -407,6 +431,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -419,13 +445,13 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
auth_token=None,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_token=auth_token)
last_frame_id = upload_image_to_pixverse(last_frame, auth_token=auth_token)
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -452,7 +478,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_token=auth_token,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.utils import ProgressBar
from comfy_extras.nodes_images import SVG # Added
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis.recraft_api import (
RecraftImageGenerationRequest,
@@ -28,9 +30,8 @@ from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_string,
)
import folder_paths
import json
import os
from server import PromptServer
import torch
from io import BytesIO
from PIL import UnidentifiedImageError
@@ -43,7 +44,7 @@ def handle_recraft_file_request(
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_token=None
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
@@ -67,7 +68,7 @@ def handle_recraft_file_request(
request=request,
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = operation.execute()
@@ -162,102 +163,6 @@ class handle_recraft_image_output:
raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: SVG):
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list[SVG]):
all_svgs = []
for svg in svgs:
all_svgs.extend(svg.data)
return SVG(all_svgs)
class SaveSVGNode:
"""
Save SVG files on disk.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = ()
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "api node/image/Recraft"
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"svg": (RecraftIO.SVG,),
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
}
}
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
# Prepare metadata JSON
metadata_dict = {}
if prompt is not None:
metadata_dict["prompt"] = prompt
if extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo)
# Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg"
# Read SVG content
svg_bytes.seek(0)
svg_content = svg_bytes.read().decode('utf-8')
# Inject metadata if available
if metadata_json:
# Create metadata element with CDATA section
metadata_element = f""" <metadata>
<![CDATA[
{metadata_json}
]]>
</metadata>
"""
# Insert metadata after opening svg tag using regex
import re
svg_content = re.sub(r'(<svg[^>]*>)', r'\1\n' + metadata_element, svg_content)
# Write the modified SVG to file
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8'))
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "images": results } }
class RecraftColorRGBNode:
"""
Create Recraft Color by choosing specific RGB values.
@@ -485,6 +390,8 @@ class RecraftTextToImageNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -497,7 +404,7 @@ class RecraftTextToImageNode:
recraft_style: RecraftStyle = None,
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
auth_token=None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -530,12 +437,19 @@ class RecraftTextToImageNode:
style_id=recraft_style.style_id,
controls=controls_api,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
images = []
urls = []
for data in response.data:
with handle_recraft_image_output():
if unique_id and data.url:
urls.append(data.url)
urls_string = '\n'.join(urls)
PromptServer.instance.send_progress_text(
f"Result URL: {urls_string}", unique_id
)
image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024)
)
@@ -620,6 +534,7 @@ class RecraftImageToImageNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@@ -630,7 +545,6 @@ class RecraftImageToImageNode:
n: int,
strength: float,
seed,
auth_token=None,
recraft_style: RecraftStyle = None,
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
@@ -668,7 +582,7 @@ class RecraftImageToImageNode:
image=image[i],
path="/proxy/recraft/images/imageToImage",
request=request,
auth_token=auth_token,
auth_kwargs=kwargs,
)
with handle_recraft_image_output():
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@@ -736,6 +650,7 @@ class RecraftImageInpaintingNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@@ -746,7 +661,6 @@ class RecraftImageInpaintingNode:
prompt: str,
n: int,
seed,
auth_token=None,
recraft_style: RecraftStyle = None,
negative_prompt: str = None,
**kwargs,
@@ -781,7 +695,7 @@ class RecraftImageInpaintingNode:
mask=mask[i:i+1],
path="/proxy/recraft/images/inpaint",
request=request,
auth_token=auth_token,
auth_kwargs=kwargs,
)
with handle_recraft_image_output():
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@@ -796,8 +710,8 @@ class RecraftTextToVectorNode:
Generates SVG synchronously based on prompt and resolution.
"""
RETURN_TYPES = (RecraftIO.SVG,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
RETURN_TYPES = ("SVG",) # Changed
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Recraft"
@@ -860,6 +774,8 @@ class RecraftTextToVectorNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -872,7 +788,7 @@ class RecraftTextToVectorNode:
seed,
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
auth_token=None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -903,11 +819,18 @@ class RecraftTextToVectorNode:
substyle=recraft_style.substyle,
controls=controls_api,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
svg_data = []
urls = []
for data in response.data:
if unique_id and data.url:
urls.append(data.url)
# Print result on each iteration in case of error
PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id
)
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
return (SVG(svg_data),)
@@ -918,8 +841,8 @@ class RecraftVectorizeImageNode:
Generates SVG synchronously from an input image.
"""
RETURN_TYPES = (RecraftIO.SVG,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
RETURN_TYPES = ("SVG",) # Changed
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Recraft"
@@ -934,13 +857,13 @@ class RecraftVectorizeImageNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
image: torch.Tensor,
auth_token=None,
**kwargs,
):
svgs = []
@@ -950,7 +873,7 @@ class RecraftVectorizeImageNode:
sub_bytes = handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/vectorize",
auth_token=auth_token,
auth_kwargs=kwargs,
)
svgs.append(SVG(sub_bytes))
pbar.update(1)
@@ -1015,6 +938,7 @@ class RecraftReplaceBackgroundNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@@ -1024,7 +948,6 @@ class RecraftReplaceBackgroundNode:
prompt: str,
n: int,
seed,
auth_token=None,
recraft_style: RecraftStyle = None,
negative_prompt: str = None,
**kwargs,
@@ -1054,7 +977,7 @@ class RecraftReplaceBackgroundNode:
image=image[i],
path="/proxy/recraft/images/replaceBackground",
request=request,
auth_token=auth_token,
auth_kwargs=kwargs,
)
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1)
@@ -1084,13 +1007,13 @@ class RecraftRemoveBackgroundNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
image: torch.Tensor,
auth_token=None,
**kwargs,
):
images = []
@@ -1100,7 +1023,7 @@ class RecraftRemoveBackgroundNode:
sub_bytes = handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/removeBackground",
auth_token=auth_token,
auth_kwargs=kwargs,
)
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1)
@@ -1135,13 +1058,13 @@ class RecraftCrispUpscaleNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
image: torch.Tensor,
auth_token=None,
**kwargs,
):
images = []
@@ -1151,7 +1074,7 @@ class RecraftCrispUpscaleNode:
sub_bytes = handle_recraft_file_request(
image=image[i],
path=self.RECRAFT_PATH,
auth_token=auth_token,
auth_kwargs=kwargs,
)
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1)
@@ -1193,7 +1116,6 @@ NODE_CLASS_MAPPINGS = {
"RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
"RecraftColorRGB": RecraftColorRGBNode,
"RecraftControls": RecraftControlsNode,
"SaveSVG": SaveSVGNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
@@ -1213,5 +1135,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
"RecraftColorRGB": "Recraft Color RGB",
"RecraftControls": "Recraft Controls",
"SaveSVG": "Save SVG",
}

View File

@@ -120,12 +120,13 @@ class StabilityStableImageUltraNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
auth_token=None):
**kwargs):
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@@ -160,7 +161,7 @@ class StabilityStableImageUltraNode:
),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -252,12 +253,13 @@ class StabilityStableImageSD_3_5Node:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
auth_token=None):
**kwargs):
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@@ -298,7 +300,7 @@ class StabilityStableImageSD_3_5Node:
),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -368,11 +370,12 @@ class StabilityUpscaleConservativeNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
auth_token=None):
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@@ -398,7 +401,7 @@ class StabilityUpscaleConservativeNode:
),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -473,11 +476,12 @@ class StabilityUpscaleCreativeNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
auth_token=None):
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@@ -506,7 +510,7 @@ class StabilityUpscaleCreativeNode:
),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()
@@ -521,7 +525,7 @@ class StabilityUpscaleCreativeNode:
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_poll: StabilityResultsGetResponse = operation.execute()
@@ -555,11 +559,12 @@ class StabilityUpscaleFastNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor,
auth_token=None):
**kwargs):
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
@@ -576,7 +581,7 @@ class StabilityUpscaleFastNode:
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api = operation.execute()

View File

@@ -3,6 +3,7 @@ import logging
import base64
import requests
import torch
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
@@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string
)
AVERAGE_DURATION_VIDEO_GEN = 32
def convert_image_to_base64(image: torch.Tensor):
if image is None:
return None
@@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
else:
return None
if hasattr(video, "gcsUri") and video.gcsUri:
return str(video.gcsUri)
return None
class VeoVideoGenerationNode(ComfyNodeABC):
"""
Generates videos from text prompts using Google's Veo API.
@@ -114,6 +133,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -133,7 +154,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW",
seed=0,
image=None,
auth_token=None,
unique_id: Optional[str] = None,
**kwargs,
):
# Prepare the instances for the request
instances = []
@@ -179,7 +201,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
instances=instances,
parameters=parameters
),
auth_token=auth_token
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
@@ -213,8 +235,11 @@ class VeoVideoGenerationNode(ComfyNodeABC):
request=Veo2GenVidPollRequest(
operationName=operation_name
),
auth_token=auth_token,
poll_interval=5.0
auth_kwargs=kwargs,
poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
# Execute the polling operation

View File

View File

@@ -0,0 +1,100 @@
import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
if len(image.shape) == 4:
return image.shape[1], image.shape[2]
elif len(image.shape) == 3:
return image.shape[0], image.shape[1]
else:
raise ValueError("Invalid image tensor shape.")
def validate_image_dimensions(
image: torch.Tensor,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
height, width = get_image_dimensions(image)
if min_width is not None and width < min_width:
raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
if max_width is not None and width > max_width:
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Image height must be at least {min_height}px, got {height}px"
)
if max_height is not None and height > max_height:
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
def validate_image_aspect_ratio(
image: torch.Tensor,
min_aspect_ratio: Optional[float] = None,
max_aspect_ratio: Optional[float] = None,
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
)
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
)
def validate_video_dimensions(
video: VideoInput,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
try:
width, height = video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
return
if min_width is not None and width < min_width:
raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
if max_width is not None and width > max_width:
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Video height must be at least {min_height}px, got {height}px"
)
if max_height is not None and height > max_height:
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
def validate_video_duration(
video: VideoInput,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
try:
duration = video.get_duration()
except Exception as e:
logging.error("Error getting duration of video: %s", e)
return
epsilon = 0.0001
if min_duration is not None and min_duration - epsilon > duration:
raise ValueError(
f"Video duration must be at least {min_duration}s, got {duration}s"
)
if max_duration is not None and duration > max_duration + epsilon:
raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s"
)

49
comfy_extras/nodes_ace.py Normal file
View File

@@ -0,0 +1,49 @@
import torch
import comfy.model_management
import node_helpers
class TextEncodeAceStepAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip, tags, lyrics, lyrics_strength):
tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return (conditioning, )
class EmptyAceStepLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
return ({"samples": latent, "type": "audio"}, )
NODE_CLASS_MAPPINGS = {
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}

76
comfy_extras/nodes_apg.py Normal file
View File

@@ -0,0 +1,76 @@
import torch
def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal
class APG:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"
def patch(self, model, eta, norm_threshold, momentum):
running_avg = 0
prev_sigma = None
def pre_cfg_function(args):
nonlocal running_avg, prev_sigma
if len(args["conds_out"]) == 1: return args["conds_out"]
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
sigma = args["sigma"][0]
cond_scale = args["cond_scale"]
if prev_sigma is not None and sigma > prev_sigma:
running_avg = 0
prev_sigma = sigma
guidance = cond - uncond
if momentum != 0:
if not torch.is_tensor(running_avg):
running_avg = guidance
else:
running_avg = momentum * running_avg + guidance
guidance = running_avg
if norm_threshold > 0:
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale = torch.minimum(
torch.ones_like(guidance_norm),
norm_threshold / guidance_norm
)
guidance = guidance * scale
guidance_parallel, guidance_orthogonal = project(guidance, cond)
modified_guidance = guidance_orthogonal + eta * guidance_parallel
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
return [modified_cond, uncond] + args["conds_out"][2:]
m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"APG": APG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import av
import torchaudio
import torch
import comfy.model_management
@@ -7,7 +8,6 @@ import folder_paths
import os
import io
import json
import struct
import random
import hashlib
import node_helpers
@@ -90,60 +90,118 @@ class VAEDecodeAudio:
return ({"waveform": audio, "sample_rate": 44100}, )
def create_vorbis_comment_block(comment_dict, last_block):
vendor_string = b'ComfyUI'
vendor_length = len(vendor_string)
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
comments = []
for key, value in comment_dict.items():
comment = f"{key}={value}".encode('utf-8')
comments.append(struct.pack('<I', len(comment)) + comment)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
user_comment_list_length = len(comments)
user_comments = b''.join(comments)
# Prepare metadata dictionary
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
if last_block:
id = b'\x84'
else:
id = b'\x04'
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
return comment_block
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
if len(comment_dict) == 0:
return flac_io
# Use original sample rate initially
sample_rate = audio["sample_rate"]
flac_io.seek(4)
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
blocks = []
last_block = False
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
while not last_block:
header = flac_io.read(4)
last_block = (header[0] & 0x80) != 0
block_type = header[0] & 0x7F
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
block_data = flac_io.read(block_length)
# Create in-memory WAV buffer
wav_buffer = io.BytesIO()
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
wav_buffer.seek(0) # Rewind for reading
if block_type == 4 or block_type == 1:
pass
else:
header = bytes([(header[0] & (~0x80))]) + header[1:]
blocks.append(header + block_data)
# Use PyAV to convert and add metadata
input_container = av.open(wav_buffer)
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
new_flac_io = io.BytesIO()
new_flac_io.write(b'fLaC')
for block in blocks:
new_flac_io.write(block)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
new_flac_io.write(flac_io.read())
return new_flac_io
# Set up the output stream with appropriate properties
input_container.streams.audio[0]
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
# Copy frames from input to output
for frame in input_container.decode(audio=0):
frame.pts = None # Let PyAV handle timestamps
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
input_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -153,50 +211,70 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_audio"
FUNCTION = "save_flac"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
buff = io.BytesIO()
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
RETURN_TYPES = ()
FUNCTION = "save_mp3"
buff = insert_or_replace_vorbis_comment(buff, metadata)
OUTPUT_NODE = True
with open(os.path.join(full_output_folder, file), 'wb') as f:
f.write(buff.getbuffer())
CATEGORY = "audio"
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
return { "ui": { "audio": results } }
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_opus"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
@@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
}

View File

@@ -0,0 +1,218 @@
import nodes
import torch
import numpy as np
from einops import rearrange
import comfy.model_management
MAX_RESOLUTION = nodes.MAX_RESOLUTION
CAMERA_DICT = {
"base_T_norm": 1.5,
"base_angle": np.pi/3,
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
}
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
def get_relative_pose(cam_params):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
cam_to_origin = 0
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, -cam_to_origin],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ abs_w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
cam_params = [Camera(cam_param) for cam_param in cam_params]
sample_wh_ratio = width / height
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
if pose_wh_ratio > sample_wh_ratio:
resized_ori_w = height * pose_wh_ratio
for cam_param in cam_params:
cam_param.fx = resized_ori_w * cam_param.fx / width
else:
resized_ori_h = width / pose_wh_ratio
for cam_param in cam_params:
cam_param.fy = resized_ori_h * cam_param.fy / height
intrinsic = np.asarray([[cam_param.fx * width,
cam_param.fy * height,
cam_param.cx * width,
cam_param.cy * height]
for cam_param in cam_params], dtype=np.float32)
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
class Camera(object):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
def __init__(self, entry):
fx, fy, cx, cy = entry[1:5]
self.fx = fx
self.fy = fy
self.cx = cx
self.cy = cy
c2w_mat = np.array(entry[7:]).reshape(4, 4)
self.c2w_mat = c2w_mat
self.w2c_mat = np.linalg.inv(c2w_mat)
def ray_condition(K, c2w, H, W, device):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# c2w: B, V, 4, 4
# K: B, V, 4
B = K.shape[0]
j, i = torch.meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
indexing='ij'
)
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
zs = torch.ones_like(i) # [B, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
# c2w @ dirctions
rays_dxo = torch.cross(rays_o, rays_d)
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def get_camera_motion(angle, T, speed, n=81):
def compute_R_form_rad_angle(angles):
theta_x, theta_y, theta_z = angles
Rx = np.array([[1, 0, 0],
[0, np.cos(theta_x), -np.sin(theta_x)],
[0, np.sin(theta_x), np.cos(theta_x)]])
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
[0, 1, 0],
[-np.sin(theta_y), 0, np.cos(theta_y)]])
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
[np.sin(theta_z), np.cos(theta_z), 0],
[0, 0, 1]])
R = np.dot(Rz, np.dot(Ry, Rx))
return R
RT = []
for i in range(n):
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
R = compute_R_form_rad_angle(_angle)
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
_RT = np.concatenate([R,_T], axis=1)
RT.append(_RT)
RT = np.stack(RT)
return RT
class WanCameraEmbedding:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
},
"optional":{
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
}
}
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
RETURN_NAMES = ("camera_embedding","width","height","length")
FUNCTION = "run"
CATEGORY = "camera"
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
"""
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
"""
motion_list = [camera_pose]
speed = speed
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
RT = get_camera_motion(angle, T, speed, length)
trajs=[]
for cp in RT.tolist():
traj=[fx,fy,cx,cy,0,0]
traj.extend(cp[0])
traj.extend(cp[1])
traj.extend(cp[2])
traj.extend([0,0,0,1])
trajs.append(traj)
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
control_camera_video = process_pose_params(cam_params, width=width, height=height)
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
control_camera_video = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
# Reshape, transpose, and view into desired shape
b, f, c, h, w = control_camera_video.shape
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
return (control_camera_video, width, height, length)
NODE_CLASS_MAPPINGS = {
"WanCameraEmbedding": WanCameraEmbedding,
}

View File

@@ -31,6 +31,7 @@ class T5TokenizerOptions:
}
}
CATEGORY = "_for_testing/conditioning"
RETURN_TYPES = ("CLIP",)
FUNCTION = "set_options"

View File

@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}
positive = node_helpers.conditioning_set_values(positive, cond)

View File

@@ -10,6 +10,10 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
import re
from io import BytesIO
from inspect import cleandoc
import torch
from comfy.comfy_types import FileLocator
@@ -71,6 +75,24 @@ class ImageFromBatch:
s = s_in[batch_index:batch_index + length].clone()
return (s,)
class ImageAddNoise:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image"
def repeat(self, image, seed, strength):
generator = torch.manual_seed(seed)
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -190,10 +212,110 @@ class SaveAnimatedPNG:
return { "ui": { "images": results, "animated": (True,)} }
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class SaveSVGNode:
"""
Save SVG files on disk.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = ()
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "image/save" # Changed
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"svg": ("SVG",), # Changed
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
}
}
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
# Prepare metadata JSON
metadata_dict = {}
if prompt is not None:
metadata_dict["prompt"] = prompt
if extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo)
# Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg"
# Read SVG content
svg_bytes.seek(0)
svg_content = svg_bytes.read().decode('utf-8')
# Inject metadata if available
if metadata_json:
# Create metadata element with CDATA section
metadata_element = f""" <metadata>
<![CDATA[
{metadata_json}
]]>
</metadata>
"""
# Insert metadata after opening svg tag using regex with a replacement function
def replacement(match):
# match.group(1) contains the captured <svg> tag
return match.group(1) + '\n' + metadata_element
# Apply the substitution
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
# Write the modified SVG to file
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8'))
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "images": results } }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"ImageAddNoise": ImageAddNoise,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
}

View File

@@ -2,6 +2,10 @@ import nodes
import folder_paths
import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
def normalize_path(path):
return path.replace('\\', '/')
@@ -21,8 +25,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -41,7 +45,14 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
@@ -59,8 +70,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -77,7 +88,14 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
class Preview3D():
@classmethod

View File

@@ -0,0 +1,323 @@
import re
from comfy.comfy_types.node_typing import IO
class StringConcatenate():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, delimiter, **kwargs):
return delimiter.join((string_a, string_b)),
class StringSubstring():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"start": (IO.INT, {}),
"end": (IO.INT, {}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, start, end, **kwargs):
return string[start:end],
class StringLength():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("length",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, **kwargs):
length = len(string)
return length,
class CaseConverter():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "UPPERCASE":
result = string.upper()
elif mode == "lowercase":
result = string.lower()
elif mode == "Capitalize":
result = string.capitalize()
elif mode == "Title Case":
result = string.title()
else:
result = string
return result,
class StringTrim():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "Both":
result = string.strip()
elif mode == "Left":
result = string.lstrip()
elif mode == "Right":
result = string.rstrip()
else:
result = string
return result,
class StringReplace():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"find": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, find, replace, **kwargs):
result = string.replace(find, replace)
return result,
class StringContains():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"substring": (IO.STRING, {"multiline": True}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("contains",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, substring, case_sensitive, **kwargs):
if case_sensitive:
contains = substring in string
else:
contains = substring.lower() in string.lower()
return contains,
class StringCompare():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
if case_sensitive:
a = string_a
b = string_b
else:
a = string_a.lower()
b = string_b.lower()
if mode == "Equal":
return a == b,
elif mode == "Starts With":
return a.startswith(b),
elif mode == "Ends With":
return a.endswith(b),
class RegexMatch():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("matches",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
match = re.search(regex_pattern, string, flags)
result = match is not None
except re.error:
result = False
return result,
class RegexExtract():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False}),
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
join_delimiter = "\n"
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
if mode == "First Match":
match = re.search(regex_pattern, string, flags)
if match:
result = match.group(0)
else:
result = ""
elif mode == "All Matches":
matches = re.findall(regex_pattern, string, flags)
if matches:
if isinstance(matches[0], tuple):
result = join_delimiter.join([m[0] for m in matches])
else:
result = join_delimiter.join(matches)
else:
result = ""
elif mode == "First Group":
match = re.search(regex_pattern, string, flags)
if match and len(match.groups()) >= group_index:
result = match.group(group_index)
else:
result = ""
elif mode == "All Groups":
matches = re.finditer(regex_pattern, string, flags)
results = []
for match in matches:
if match.groups() and len(match.groups()) >= group_index:
results.append(match.group(group_index))
result = join_delimiter.join(results)
else:
result = ""
except re.error:
result = ""
return result,
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
"StringLength": StringLength,
"CaseConverter": CaseConverter,
"StringTrim": StringTrim,
"StringReplace": StringReplace,
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract
}
NODE_DISPLAY_NAME_MAPPINGS = {
"StringConcatenate": "Concatenate",
"StringSubstring": "Substring",
"StringLength": "Length",
"CaseConverter": "Case Converter",
"StringTrim": "Trim",
"StringReplace": "Replace",
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract"
}

View File

@@ -297,6 +297,52 @@ class TrimVideoLatent:
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
class WanCameraImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
@@ -305,4 +351,5 @@ NODE_CLASS_MAPPINGS = {
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.31"
__version__ = "0.3.34"

View File

@@ -146,6 +146,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please

View File

@@ -1,28 +0,0 @@
import importlib.util
import shutil
import os
import ctypes
import logging
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, "rb") as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
break
try:
ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)

View File

@@ -125,13 +125,6 @@ if __name__ == "__main__":
import cuda_malloc
if args.windows_standalone_build:
try:
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
import comfy.utils
import execution
@@ -270,7 +263,7 @@ def start_comfyui(asyncio_loop=None):
q = execution.PromptQueue(prompt_server)
hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
hook_breaker_ac10a0.restore_functions()
cuda_malloc_warning()

View File

@@ -246,6 +246,9 @@ class ConditioningZeroOut:
pooled_output = d.get("pooled_output", None)
if pooled_output is not None:
d["pooled_output"] = torch.zeros_like(pooled_output)
conditioning_lyrics = d.get("conditioning_lyrics", None)
if conditioning_lyrics is not None:
d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics)
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c, )
@@ -917,7 +920,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -1937,7 +1940,7 @@ class ImagePadForOutpaint:
mask[top:top + d2, left:left + d3] = t
return (new_image, mask)
return (new_image, mask.unsqueeze(0))
NODE_CLASS_MAPPINGS = {
@@ -2258,9 +2261,22 @@ def init_builtin_extra_nodes():
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_fresca.py",
"nodes_apg.py",
"nodes_preview_any.py",
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
]
import_failed = []
for node_file in extras_files:
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
import_failed.append(node_file)
return import_failed
def init_builtin_api_nodes():
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
api_nodes_files = [
"nodes_ideogram.py",
@@ -2276,11 +2292,10 @@ def init_builtin_extra_nodes():
"nodes_pika.py",
]
import_failed = []
for node_file in extras_files:
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
import_failed.append(node_file)
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
return api_nodes_files
import_failed = []
for node_file in api_nodes_files:
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
import_failed.append(node_file)
@@ -2288,14 +2303,29 @@ def init_builtin_extra_nodes():
return import_failed
def init_extra_nodes(init_custom_nodes=True):
def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
import_failed = init_builtin_extra_nodes()
import_failed_api = []
if init_api_nodes:
import_failed_api = init_builtin_api_nodes()
if init_custom_nodes:
init_external_custom_nodes()
else:
logging.info("Skipping loading of custom nodes")
if len(import_failed_api) > 0:
logging.warning("WARNING: some comfy_api_nodes/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
for node in import_failed_api:
logging.warning("IMPORT FAILED: {}".format(node))
logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
if args.windows_standalone_build:
logging.warning("Please run the update script: update/update_comfyui.bat")
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")
if len(import_failed) > 0:
logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
for node in import_failed:

904
openapi.yaml Normal file
View File

@@ -0,0 +1,904 @@
openapi: 3.0.3
info:
title: ComfyUI API
description: |
API for ComfyUI - A powerful and modular UI for Stable Diffusion.
This API allows you to interact with ComfyUI programmatically, including:
- Submitting workflows for execution
- Managing the execution queue
- Retrieving generated images
- Managing models
- Retrieving node information
version: 1.0.0
license:
name: GNU General Public License v3.0
url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE
servers:
- url: /
description: Default ComfyUI server
tags:
- name: workflow
description: Workflow execution and management
- name: queue
description: Queue management
- name: image
description: Image handling
- name: node
description: Node information
- name: model
description: Model management
- name: system
description: System information
- name: internal
description: Internal API routes
paths:
/prompt:
get:
tags:
- workflow
summary: Get information about current prompt execution
description: Returns information about the current prompt in the execution queue
operationId: getPromptInfo
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/PromptInfo'
post:
tags:
- workflow
summary: Submit a workflow for execution
description: |
Submit a workflow to be executed by the backend.
The workflow is a JSON object describing the nodes and their connections.
operationId: executePrompt
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PromptRequest'
responses:
'200':
description: Success - Prompt accepted
content:
application/json:
schema:
$ref: '#/components/schemas/PromptResponse'
'400':
description: Invalid prompt
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
/queue:
get:
tags:
- queue
summary: Get queue information
description: Returns information about running and pending items in the queue
operationId: getQueueInfo
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/QueueInfo'
post:
tags:
- queue
summary: Manage queue
description: Clear the queue or delete specific items
operationId: manageQueue
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clear:
type: boolean
description: If true, clears the entire queue
delete:
type: array
description: Array of prompt IDs to delete from the queue
items:
type: string
format: uuid
responses:
'200':
description: Success
/interrupt:
post:
tags:
- workflow
summary: Interrupt the current execution
description: Interrupts the currently running workflow execution
operationId: interruptExecution
responses:
'200':
description: Success
/free:
post:
tags:
- system
summary: Free resources
description: Unload models and/or free memory
operationId: freeResources
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
unload_models:
type: boolean
description: If true, unloads models from memory
free_memory:
type: boolean
description: If true, frees GPU memory
responses:
'200':
description: Success
/history:
get:
tags:
- workflow
summary: Get execution history
description: Returns the history of executed workflows
operationId: getHistory
parameters:
- name: max_items
in: query
description: Maximum number of history items to return
required: false
schema:
type: integer
format: int32
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/HistoryItem'
post:
tags:
- workflow
summary: Manage history
description: Clear history or delete specific items
operationId: manageHistory
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clear:
type: boolean
description: If true, clears the entire history
delete:
type: array
description: Array of prompt IDs to delete from history
items:
type: string
format: uuid
responses:
'200':
description: Success
/history/{prompt_id}:
get:
tags:
- workflow
summary: Get specific history item
description: Returns a specific history item by ID
operationId: getHistoryItem
parameters:
- name: prompt_id
in: path
description: ID of the prompt to retrieve
required: true
schema:
type: string
format: uuid
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/HistoryItem'
/object_info:
get:
tags:
- node
summary: Get all node information
description: Returns information about all available nodes
operationId: getNodeInfo
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
$ref: '#/components/schemas/NodeInfo'
/object_info/{node_class}:
get:
tags:
- node
summary: Get specific node information
description: Returns information about a specific node class
operationId: getNodeClassInfo
parameters:
- name: node_class
in: path
description: Name of the node class
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
$ref: '#/components/schemas/NodeInfo'
/upload/image:
post:
tags:
- image
summary: Upload an image
description: Uploads an image to the server
operationId: uploadImage
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
image:
type: string
format: binary
description: The image file to upload
overwrite:
type: string
description: Whether to overwrite if file exists (true/false)
type:
type: string
enum: [input, temp, output]
description: Type of directory to store the image in
subfolder:
type: string
description: Subfolder to store the image in
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
name:
type: string
description: Filename of the uploaded image
subfolder:
type: string
description: Subfolder the image was stored in
type:
type: string
description: Type of directory the image was stored in
'400':
description: Bad request
/upload/mask:
post:
tags:
- image
summary: Upload a mask for an image
description: Uploads a mask image and applies it to a referenced original image
operationId: uploadMask
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
image:
type: string
format: binary
description: The mask image file to upload
original_ref:
type: string
description: JSON string containing reference to the original image
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
name:
type: string
description: Filename of the uploaded mask
subfolder:
type: string
description: Subfolder the mask was stored in
type:
type: string
description: Type of directory the mask was stored in
'400':
description: Bad request
/view:
get:
tags:
- image
summary: View an image
description: Retrieves an image from the server
operationId: viewImage
parameters:
- name: filename
in: query
description: Name of the file to retrieve
required: true
schema:
type: string
- name: type
in: query
description: Type of directory to retrieve from
required: false
schema:
type: string
enum: [input, temp, output]
default: output
- name: subfolder
in: query
description: Subfolder to retrieve from
required: false
schema:
type: string
- name: preview
in: query
description: Preview options (format;quality)
required: false
schema:
type: string
- name: channel
in: query
description: Channel to retrieve (rgb, a, rgba)
required: false
schema:
type: string
enum: [rgb, a, rgba]
default: rgba
responses:
'200':
description: Success
content:
image/*:
schema:
type: string
format: binary
'400':
description: Bad request
'404':
description: File not found
/view_metadata/{folder_name}:
get:
tags:
- model
summary: View model metadata
description: Retrieves metadata from a safetensors file
operationId: viewModelMetadata
parameters:
- name: folder_name
in: path
description: Name of the model folder
required: true
schema:
type: string
- name: filename
in: query
description: Name of the safetensors file
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
'404':
description: File not found
/models:
get:
tags:
- model
summary: Get model types
description: Returns a list of available model types
operationId: getModelTypes
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/models/{folder}:
get:
tags:
- model
summary: Get models of a specific type
description: Returns a list of available models of a specific type
operationId: getModels
parameters:
- name: folder
in: path
description: Model type folder
required: true
schema:
type: string
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
'404':
description: Folder not found
/embeddings:
get:
tags:
- model
summary: Get embeddings
description: Returns a list of available embeddings
operationId: getEmbeddings
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/extensions:
get:
tags:
- system
summary: Get extensions
description: Returns a list of available extensions
operationId: getExtensions
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
/system_stats:
get:
tags:
- system
summary: Get system statistics
description: Returns system information including RAM, VRAM, and ComfyUI version
operationId: getSystemStats
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/SystemStats'
/ws:
get:
tags:
- workflow
summary: WebSocket connection
description: |
Establishes a WebSocket connection for real-time communication.
This endpoint is used for receiving progress updates, status changes, and results from workflow executions.
operationId: webSocketConnect
parameters:
- name: clientId
in: query
description: Optional client ID for reconnection
required: false
schema:
type: string
responses:
'101':
description: Switching Protocols to WebSocket
/internal/logs:
get:
tags:
- internal
summary: Get logs
description: Returns system logs as a single string
operationId: getLogs
responses:
'200':
description: Success
content:
application/json:
schema:
type: string
/internal/logs/raw:
get:
tags:
- internal
summary: Get raw logs
description: Returns raw system logs with terminal size information
operationId: getRawLogs
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
entries:
type: array
items:
type: object
properties:
t:
type: string
description: Timestamp
m:
type: string
description: Message
size:
type: object
properties:
cols:
type: integer
description: Terminal columns
rows:
type: integer
description: Terminal rows
/internal/logs/subscribe:
patch:
tags:
- internal
summary: Subscribe to logs
description: Subscribe or unsubscribe to log updates
operationId: subscribeToLogs
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
clientId:
type: string
description: Client ID
enabled:
type: boolean
description: Whether to enable or disable subscription
responses:
'200':
description: Success
/internal/folder_paths:
get:
tags:
- internal
summary: Get folder paths
description: Returns a map of folder names to their paths
operationId: getFolderPaths
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
additionalProperties:
type: string
/internal/files/{directory_type}:
get:
tags:
- internal
summary: Get files
description: Returns a list of files in a specific directory type
operationId: getFiles
parameters:
- name: directory_type
in: path
description: Type of directory (output, input, temp)
required: true
schema:
type: string
enum: [output, input, temp]
responses:
'200':
description: Success
content:
application/json:
schema:
type: array
items:
type: string
'400':
description: Invalid directory type
components:
schemas:
PromptRequest:
type: object
required:
- prompt
properties:
prompt:
type: object
description: The workflow graph to execute
additionalProperties: true
number:
type: number
description: Priority number for the queue (lower numbers have higher priority)
front:
type: boolean
description: If true, adds the prompt to the front of the queue
extra_data:
type: object
description: Extra data to be associated with the prompt
additionalProperties: true
client_id:
type: string
description: Client ID for attribution of the prompt
PromptResponse:
type: object
properties:
prompt_id:
type: string
format: uuid
description: Unique identifier for the prompt execution
number:
type: number
description: Priority number in the queue
node_errors:
type: object
description: Any errors in the nodes of the prompt
additionalProperties: true
ErrorResponse:
type: object
properties:
error:
type: object
properties:
type:
type: string
description: Error type
message:
type: string
description: Error message
details:
type: string
description: Detailed error information
extra_info:
type: object
description: Additional error information
additionalProperties: true
node_errors:
type: object
description: Node-specific errors
additionalProperties: true
PromptInfo:
type: object
properties:
exec_info:
type: object
properties:
queue_remaining:
type: integer
description: Number of items remaining in the queue
QueueInfo:
type: object
properties:
queue_running:
type: array
items:
type: object
description: Currently running items
additionalProperties: true
queue_pending:
type: array
items:
type: object
description: Pending items in the queue
additionalProperties: true
HistoryItem:
type: object
properties:
prompt_id:
type: string
format: uuid
description: Unique identifier for the prompt
prompt:
type: object
description: The workflow graph that was executed
additionalProperties: true
extra_data:
type: object
description: Additional data associated with the execution
additionalProperties: true
outputs:
type: object
description: Output data from the execution
additionalProperties: true
NodeInfo:
type: object
properties:
input:
type: object
description: Input specifications for the node
additionalProperties: true
input_order:
type: object
description: Order of inputs for display
additionalProperties:
type: array
items:
type: string
output:
type: array
items:
type: string
description: Output types of the node
output_is_list:
type: array
items:
type: boolean
description: Whether each output is a list
output_name:
type: array
items:
type: string
description: Names of the outputs
name:
type: string
description: Internal name of the node
display_name:
type: string
description: Display name of the node
description:
type: string
description: Description of the node
python_module:
type: string
description: Python module implementing the node
category:
type: string
description: Category of the node
output_node:
type: boolean
description: Whether this is an output node
output_tooltips:
type: array
items:
type: string
description: Tooltips for outputs
deprecated:
type: boolean
description: Whether the node is deprecated
experimental:
type: boolean
description: Whether the node is experimental
api_node:
type: boolean
description: Whether this is an API node
SystemStats:
type: object
properties:
system:
type: object
properties:
os:
type: string
description: Operating system
ram_total:
type: number
description: Total system RAM in bytes
ram_free:
type: number
description: Free system RAM in bytes
comfyui_version:
type: string
description: ComfyUI version
python_version:
type: string
description: Python version
pytorch_version:
type: string
description: PyTorch version
embedded_python:
type: boolean
description: Whether using embedded Python
argv:
type: array
items:
type: string
description: Command line arguments
devices:
type: array
items:
type: object
properties:
name:
type: string
description: Device name
type:
type: string
description: Device type
index:
type: integer
description: Device index
vram_total:
type: number
description: Total VRAM in bytes
vram_free:
type: number
description: Free VRAM in bytes
torch_vram_total:
type: number
description: Total VRAM as reported by PyTorch
torch_vram_free:
type: number
description: Free VRAM as reported by PyTorch

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.31"
version = "0.3.34"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.18.9
comfyui-workflow-templates==0.1.11
comfyui-frontend-package==1.19.9
comfyui-workflow-templates==0.1.14
torch
torchsde
torchvision

View File

@@ -101,6 +101,14 @@ prompt_text = """
def queue_prompt(prompt):
p = {"prompt": prompt}
# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data`` field of the payload.
# p["extra_data"] = {
# "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
# }
# See: https://docs.comfy.org/tutorials/api-nodes/overview
# Generate a key here: https://platform.comfy.org/login
data = json.dumps(p).encode('utf-8')
req = request.Request("http://127.0.0.1:8188/prompt", data=data)
request.urlopen(req)

View File

@@ -32,12 +32,13 @@ from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
async def send_socket_catch_exception(function, message):
try:
@@ -878,3 +879,15 @@ class PromptServer():
logging.warning(traceback.format_exc())
return json_data
def send_progress_text(
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
):
if isinstance(text, str):
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
self.send_sync(BinaryEventTypes.TEXT, message, sid)

74
tests-api/README.md Normal file
View File

@@ -0,0 +1,74 @@
# ComfyUI API Testing
This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI.
## Setup
1. Install the required dependencies:
```bash
pip install -r requirements.txt
```
2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188)
## Running the Tests
Run all tests with pytest:
```bash
cd tests-api
pytest
```
Run specific test files:
```bash
pytest test_spec_validation.py
pytest test_endpoint_existence.py
pytest test_schema_validation.py
pytest test_api_by_tag.py
```
Run tests with more verbose output:
```bash
pytest -v
```
## Test Categories
The tests are organized into several categories:
1. **Spec Validation**: Validates that the OpenAPI specification is valid.
2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server.
3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec.
4. **Tag-Based Tests**: Tests that the API's tag organization is consistent.
## Using a Different Server
By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable:
```bash
COMFYUI_SERVER_URL=http://example.com:8188 pytest
```
## Test Structure
- `conftest.py`: Contains pytest fixtures used by the tests.
- `utils/`: Contains utility functions for working with the OpenAPI spec.
- `test_*.py`: The actual test files.
- `resources/`: Contains resources used by the tests (e.g., sample workflows).
## Extending the Tests
To add new tests:
1. For testing new endpoints, add them to the appropriate test file based on their category.
2. For testing more complex functionality, create a new test file following the established patterns.
## Notes
- Tests that require a running server will be skipped if the server is not available.
- Some tests may fail if the server doesn't match the specification exactly.
- The tests don't modify any data on the server (they're read-only).

141
tests-api/conftest.py Normal file
View File

@@ -0,0 +1,141 @@
"""
Test fixtures for API testing
"""
import os
import pytest
import yaml
import requests
import logging
from typing import Dict, Any, Generator, Optional
from urllib.parse import urljoin
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default server configuration
DEFAULT_SERVER_URL = "http://127.0.0.1:8188"
@pytest.fixture(scope="session")
def api_spec_path() -> str:
"""
Get the path to the OpenAPI specification file
Returns:
Path to the OpenAPI specification file
"""
return os.path.abspath(os.path.join(
os.path.dirname(__file__),
"..",
"openapi.yaml"
))
@pytest.fixture(scope="session")
def api_spec(api_spec_path: str) -> Dict[str, Any]:
"""
Load the OpenAPI specification
Args:
api_spec_path: Path to the spec file
Returns:
Parsed OpenAPI specification
"""
with open(api_spec_path, 'r') as f:
return yaml.safe_load(f)
@pytest.fixture(scope="session")
def base_url() -> str:
"""
Get the base URL for the API server
Returns:
Base URL string
"""
# Allow overriding via environment variable
return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL)
@pytest.fixture(scope="session")
def server_available(base_url: str) -> bool:
"""
Check if the server is available
Args:
base_url: Base URL for the API
Returns:
True if the server is available, False otherwise
"""
try:
response = requests.get(base_url, timeout=2)
return response.status_code == 200
except requests.RequestException:
logger.warning(f"Server at {base_url} is not available")
return False
@pytest.fixture
def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
"""
Create a requests session for API testing
Args:
base_url: Base URL for the API
Yields:
Requests session configured for the API
"""
session = requests.Session()
# Helper function to construct URLs
def get_url(path: str) -> str:
return urljoin(base_url, path)
# Add url helper to the session
session.get_url = get_url # type: ignore
yield session
# Cleanup
session.close()
@pytest.fixture
def api_get_json(api_client: requests.Session):
"""
Helper fixture for making GET requests and parsing JSON responses
Args:
api_client: API client session
Returns:
Function that makes GET requests and returns JSON
"""
def _get_json(path: str, **kwargs):
url = api_client.get_url(path) # type: ignore
response = api_client.get(url, **kwargs)
if response.status_code == 200:
try:
return response.json()
except ValueError:
return None
return None
return _get_json
@pytest.fixture
def require_server(server_available):
"""
Skip tests if server is not available
Args:
server_available: Whether the server is available
"""
if not server_available:
pytest.skip("Server is not available")

View File

@@ -0,0 +1,6 @@
pytest>=7.0.0
pytest-asyncio>=0.21.0
openapi-spec-validator>=0.5.0
jsonschema>=4.17.0
requests>=2.28.0
pyyaml>=6.0.0

View File

@@ -0,0 +1,279 @@
"""
Tests for API endpoints grouped by tags
"""
import pytest
import logging
import sys
import os
from typing import Dict, Any, Set
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define functions inline to avoid import issues
def get_all_endpoints(spec):
"""
Extract all endpoints from an OpenAPI spec
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
def get_all_tags(spec):
"""
Get all tags used in the API spec
"""
tags = set()
for path_item in spec['paths'].values():
for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags'])
return tags
def extract_endpoints_by_tag(spec, tag):
"""
Extract all endpoints with a specific tag
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
if tag in operation.get('tags', []):
endpoints.append({
'path': path,
'method': method.lower(),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture
def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
"""
Get all tags from the API spec
Args:
api_spec: Loaded OpenAPI spec
Returns:
Set of tag names
"""
return get_all_tags(api_spec)
def test_api_has_tags(api_tags: Set[str]):
"""
Test that the API has defined tags
Args:
api_tags: Set of tags
"""
assert len(api_tags) > 0, "API spec should have at least one tag"
# Log the tags
logger.info(f"API spec has the following tags: {sorted(api_tags)}")
@pytest.mark.parametrize("tag", [
"workflow",
"image",
"model",
"node",
"system"
])
def test_core_tags_exist(api_tags: Set[str], tag: str):
"""
Test that core tags exist in the API spec
Args:
api_tags: Set of tags
tag: Tag to check
"""
assert tag in api_tags, f"API spec should have '{tag}' tag"
def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'workflow' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "workflow")
assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"
# Check for key workflow endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'image' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "image")
assert len(endpoints) > 0, "No endpoints found with 'image' tag"
# Check for key image endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
assert "/view" in endpoint_paths, "Image tag should include /view endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'model' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "model")
assert len(endpoints) > 0, "No endpoints found with 'model' tag"
# Check for key model endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/models" in endpoint_paths, "Model tag should include /models endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'node' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "node")
assert len(endpoints) > 0, "No endpoints found with 'node' tag"
# Check for key node endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'system' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "system")
assert len(endpoints) > 0, "No endpoints found with 'system' tag"
# Check for key system endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
"""
Test that the 'internal' tag has appropriate endpoints
Args:
api_spec: Loaded OpenAPI spec
"""
endpoints = extract_endpoints_by_tag(api_spec, "internal")
assert len(endpoints) > 0, "No endpoints found with 'internal' tag"
# Check for key internal endpoints
endpoint_paths = [e["path"] for e in endpoints]
assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"
# Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
for e in endpoints:
logger.info(f" {e['method'].upper()} {e['path']}")
def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
"""
Test that operation IDs follow a consistent pattern with their tag
Args:
api_spec: Loaded OpenAPI spec
"""
failures = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation and 'tags' in operation and operation['tags']:
op_id = operation['operationId']
primary_tag = operation['tags'][0].lower()
# Check if operationId starts with primary tag prefix
# This is a common convention, but might need adjusting
if not (op_id.startswith(primary_tag) or
any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
failures.append({
'path': path,
'method': method,
'operationId': op_id,
'primary_tag': primary_tag
})
# Log failures for diagnosis but don't fail the test
# as this is a style/convention check
if failures:
logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
for f in failures:
logger.warning(f" {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")

View File

@@ -0,0 +1,240 @@
"""
Tests for endpoint existence and basic response codes
"""
import pytest
import requests
import logging
import sys
import os
from typing import Dict, Any, List
from urllib.parse import urljoin
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define get_all_endpoints function inline to avoid import issues
def get_all_endpoints(spec):
"""
Extract all endpoints from an OpenAPI spec
Args:
spec: Parsed OpenAPI specification
Returns:
List of dicts with path, method, and tags for each endpoint
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Get all endpoints from the API spec
Args:
api_spec: Loaded OpenAPI spec
Returns:
List of endpoint information
"""
return get_all_endpoints(api_spec)
def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
"""
Test that endpoints are defined in the spec
Args:
all_endpoints: List of endpoint information
"""
# Simple check that we have endpoints defined
assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"
# Log the endpoints for informational purposes
logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
for endpoint in all_endpoints:
logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}")
@pytest.mark.parametrize("endpoint_path", [
"/", # Root path
"/prompt", # Get prompt info
"/queue", # Get queue
"/models", # Get model types
"/object_info", # Get node info
"/system_stats" # Get system stats
])
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
"""
Test that basic GET endpoints exist and respond
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
endpoint_path: Path to test
"""
url = api_client.get_url(endpoint_path) # type: ignore
try:
response = api_client.get(url)
# We're just checking that the endpoint exists and returns some kind of response
# Not necessarily a 200 status code
assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"
logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
def test_websocket_endpoint_exists(require_server, base_url: str):
"""
Test that the WebSocket endpoint exists
Args:
require_server: Fixture that skips if server is not available
base_url: Base server URL
"""
ws_url = urljoin(base_url, "/ws")
# For WebSocket, we can't use a normal GET request
# Instead, we make a HEAD request to check if the endpoint exists
try:
response = requests.head(ws_url)
# WebSocket endpoints often return a 400 Bad Request for HEAD requests
# but a 404 would indicate the endpoint doesn't exist
assert response.status_code != 404, "WebSocket endpoint /ws does not exist"
logger.info(f"WebSocket endpoint exists with status code {response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")
def test_api_models_folder_endpoint(require_server, api_client):
"""
Test that the /models/{folder} endpoint exists and responds
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
# First get available model types
models_url = api_client.get_url("/models") # type: ignore
try:
models_response = api_client.get(models_url)
assert models_response.status_code == 200, "Failed to get model types"
model_types = models_response.json()
# Skip if no model types available
if not model_types:
pytest.skip("No model types available to test")
# Test with the first model type
model_type = model_types[0]
models_folder_url = api_client.get_url(f"/models/{model_type}") # type: ignore
folder_response = api_client.get(models_folder_url)
# We're just checking that the endpoint exists
assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"
logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, IndexError) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_api_object_info_node_endpoint(require_server, api_client):
"""
Test that the /object_info/{node_class} endpoint exists and responds
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
# First get available node classes
objects_url = api_client.get_url("/object_info") # type: ignore
try:
objects_response = api_client.get(objects_url)
assert objects_response.status_code == 200, "Failed to get object info"
node_classes = objects_response.json()
# Skip if no node classes available
if not node_classes:
pytest.skip("No node classes available to test")
# Test with the first node class
node_class = next(iter(node_classes.keys()))
node_url = api_client.get_url(f"/object_info/{node_class}") # type: ignore
node_response = api_client.get(node_url)
# We're just checking that the endpoint exists
assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"
logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")
except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, StopIteration) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_internal_endpoints_exist(require_server, api_client):
"""
Test that internal endpoints exist
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
"""
internal_endpoints = [
"/internal/logs",
"/internal/logs/raw",
"/internal/folder_paths",
"/internal/files/output"
]
for endpoint in internal_endpoints:
url = api_client.get_url(endpoint) # type: ignore
try:
response = api_client.get(url)
# We're just checking that the endpoint exists
assert response.status_code != 404, f"Endpoint {endpoint} does not exist"
logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")
except requests.RequestException as e:
logger.warning(f"Request to {endpoint} failed: {str(e)}")
# Don't fail the test as internal endpoints might be restricted

View File

@@ -0,0 +1,440 @@
"""
Tests for validating API responses against OpenAPI schema
"""
import pytest
import requests
import logging
import sys
import os
import json
from typing import Dict, Any
# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# Define validation functions inline to avoid import issues
def get_endpoint_schema(
spec,
path,
method,
status_code = '200'
):
"""
Extract response schema for a specific endpoint from OpenAPI spec
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle status code not found
responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses:
return None
# Handle no content defined
if 'content' not in responses[status_code]:
return None
# Get schema from first content type
content_types = responses[status_code]['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
def resolve_schema_refs(schema, spec):
"""
Resolve $ref references in a schema
"""
if not isinstance(schema, dict):
return schema
result = {}
for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference
ref_path = value[2:].split('/')
ref_value = spec
for path_part in ref_path:
ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value)
elif isinstance(value, dict):
# Recursively resolve refs in nested dictionaries
result[key] = resolve_schema_refs(value, spec)
elif isinstance(value, list):
# Recursively resolve refs in list items
result[key] = [
resolve_schema_refs(item, spec) if isinstance(item, dict) else item
for item in value
]
else:
# Pass through other values
result[key] = value
return result
def validate_response(
response_data,
spec,
path,
method,
status_code = '200'
):
"""
Validate a response against the OpenAPI schema
"""
schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None:
return {
'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
}
# Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec)
try:
import jsonschema
jsonschema.validate(instance=response_data, schema=resolved_schema)
return {'valid': True, 'errors': []}
except jsonschema.exceptions.ValidationError as e:
# Extract more detailed error information
path = ".".join(str(p) for p in e.path) if e.path else "root"
instance = e.instance if not isinstance(e.instance, dict) else "..."
schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"
detailed_error = (
f"Validation error at path: {path}\n"
f"Schema path: {schema_path}\n"
f"Error message: {e.message}\n"
f"Failed instance: {instance}\n"
)
return {'valid': False, 'errors': [detailed_error]}
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.mark.parametrize("endpoint_path,method", [
("/system_stats", "get"),
("/prompt", "get"),
("/queue", "get"),
("/models", "get"),
("/embeddings", "get")
])
def test_response_schema_validation(
require_server,
api_client,
api_spec: Dict[str, Any],
endpoint_path: str,
method: str
):
"""
Test that API responses match the defined schema
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
endpoint_path: Path to test
method: HTTP method to test
"""
url = api_client.get_url(endpoint_path) # type: ignore
# Skip if no schema defined
schema = get_endpoint_schema(api_spec, endpoint_path, method)
if not schema:
pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")
try:
if method.lower() == "get":
response = api_client.get(url)
else:
pytest.skip(f"Method {method} not implemented for automated testing")
return
# Skip if response is not 200
if response.status_code != 200:
pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
return
# Skip if response is not JSON
try:
response_data = response.json()
except ValueError:
pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
return
# Validate the response
validation_result = validate_response(
response_data,
api_spec,
endpoint_path,
method
)
if validation_result['valid']:
logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
else:
for error in validation_result['errors']:
logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")
assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the system_stats endpoint response in detail
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/system_stats") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get system stats"
# Parse response
stats = response.json()
# Validate high-level structure
assert 'system' in stats, "Response missing 'system' field"
assert 'devices' in stats, "Response missing 'devices' field"
# Validate system fields
system = stats['system']
assert 'os' in system, "System missing 'os' field"
assert 'ram_total' in system, "System missing 'ram_total' field"
assert 'ram_free' in system, "System missing 'ram_free' field"
assert 'comfyui_version' in system, "System missing 'comfyui_version' field"
# Validate devices fields
devices = stats['devices']
assert isinstance(devices, list), "Devices should be a list"
if devices:
device = devices[0]
assert 'name' in device, "Device missing 'name' field"
assert 'type' in device, "Device missing 'type' field"
assert 'vram_total' in device, "Device missing 'vram_total' field"
assert 'vram_free' in device, "Device missing 'vram_free' field"
# Perform schema validation
validation_result = validate_response(
stats,
api_spec,
"/system_stats",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /system_stats: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/system_stats", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print sample of the response
logger.error(f"Response:\n{json.dumps(stats, indent=2)}")
assert validation_result['valid'], "System stats response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /system_stats failed: {str(e)}")
def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the models endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/models") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get models"
# Parse response
models = response.json()
# Validate it's a list
assert isinstance(models, list), "Models response should be a list"
# Each item should be a string
for model in models:
assert isinstance(model, str), "Each model type should be a string"
# Perform schema validation
validation_result = validate_response(
models,
api_spec,
"/models",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /models: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/models", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response
sample_models = models[:5] if isinstance(models, list) else models
logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")
assert validation_result['valid'], "Models response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /models failed: {str(e)}")
def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the object_info endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/object_info") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get object info"
# Parse response
objects = response.json()
# Validate it's an object
assert isinstance(objects, dict), "Object info response should be an object"
# Check if we have any objects
if objects:
# Get the first object
first_obj_name = next(iter(objects.keys()))
first_obj = objects[first_obj_name]
# Validate first object has required fields
assert 'input' in first_obj, "Object missing 'input' field"
assert 'output' in first_obj, "Object missing 'output' field"
assert 'name' in first_obj, "Object missing 'name' field"
# Perform schema validation
validation_result = validate_response(
objects,
api_spec,
"/object_info",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /object_info: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/object_info", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Also print a small sample of the response
sample = dict(list(objects.items())[:1]) if objects else {}
logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")
assert validation_result['valid'], "Object info response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /object_info failed: {str(e)}")
except (KeyError, StopIteration) as e:
pytest.fail(f"Failed to process response: {str(e)}")
def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
"""
Test the queue endpoint response
Args:
require_server: Fixture that skips if server is not available
api_client: API client fixture
api_spec: Loaded OpenAPI spec
"""
url = api_client.get_url("/queue") # type: ignore
try:
response = api_client.get(url)
assert response.status_code == 200, "Failed to get queue"
# Parse response
queue = response.json()
# Validate structure
assert 'queue_running' in queue, "Queue missing 'queue_running' field"
assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"
# Each should be a list
assert isinstance(queue['queue_running'], list), "queue_running should be a list"
assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"
# Perform schema validation
validation_result = validate_response(
queue,
api_spec,
"/queue",
"get"
)
# Print detailed error if validation fails
if not validation_result['valid']:
for error in validation_result['errors']:
logger.error(f"Validation error for /queue: {error}")
# Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/queue", "get")
if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response
logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")
assert validation_result['valid'], "Queue response does not match schema"
except requests.RequestException as e:
pytest.fail(f"Request to /queue failed: {str(e)}")

View File

@@ -0,0 +1,144 @@
"""
Tests for validating the OpenAPI specification
"""
import pytest
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from typing import Dict, Any
def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI specification is valid
Args:
api_spec: Loaded OpenAPI spec
"""
try:
validate_spec(api_spec)
except OpenAPISpecValidatorError as e:
pytest.fail(f"OpenAPI spec validation failed: {str(e)}")
def test_spec_has_info(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has the required info section
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'info' in api_spec, "Spec must have info section"
assert 'title' in api_spec['info'], "Info must have title"
assert 'version' in api_spec['info'], "Info must have version"
def test_spec_has_paths(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has paths defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'paths' in api_spec, "Spec must have paths section"
assert len(api_spec['paths']) > 0, "Spec must have at least one path"
def test_spec_has_components(api_spec: Dict[str, Any]):
"""
Test that the OpenAPI spec has components defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert 'components' in api_spec, "Spec must have components section"
assert 'schemas' in api_spec['components'], "Components must have schemas"
def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core workflow endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint"
assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt"
assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt"
def test_image_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core image endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint"
assert '/view' in api_spec['paths'], "Spec must define /view endpoint"
def test_model_endpoints_exist(api_spec: Dict[str, Any]):
"""
Test that core model endpoints are defined
Args:
api_spec: Loaded OpenAPI spec
"""
assert '/models' in api_spec['paths'], "Spec must define /models endpoint"
assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint"
def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
"""
Test that all operationIds are unique
Args:
api_spec: Loaded OpenAPI spec
"""
operation_ids = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation:
operation_ids.append(operation['operationId'])
# Check for duplicates
duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"
def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
"""
Test that all endpoints have operationIds
Args:
api_spec: Loaded OpenAPI spec
"""
missing = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' not in operation:
missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"
def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
"""
Test that all endpoints have tags
Args:
api_spec: Loaded OpenAPI spec
"""
missing = []
for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'tags' not in operation or not operation['tags']:
missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without tags: {missing}"

View File

@@ -0,0 +1,157 @@
"""
Utilities for working with OpenAPI schemas
"""
from typing import Any, Dict, List, Optional, Set, Tuple
def extract_required_parameters(
spec: Dict[str, Any],
path: str,
method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Extract required parameters for a specific endpoint
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
Returns:
Tuple of (path_params, query_params) containing required parameters
"""
method = method.lower()
path_params = []
query_params = []
# Handle path not found
if path not in spec['paths']:
return path_params, query_params
# Handle method not found
if method not in spec['paths'][path]:
return path_params, query_params
# Get parameters
params = spec['paths'][path][method].get('parameters', [])
for param in params:
if param.get('required', False):
if param.get('in') == 'path':
path_params.append(param)
elif param.get('in') == 'query':
query_params.append(param)
return path_params, query_params
def get_request_body_schema(
spec: Dict[str, Any],
path: str,
method: str
) -> Optional[Dict[str, Any]]:
"""
Get request body schema for a specific endpoint
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
Returns:
Request body schema or None if not found
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle no request body
request_body = spec['paths'][path][method].get('requestBody', {})
if not request_body or 'content' not in request_body:
return None
# Get schema from first content type
content_types = request_body['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
"""
Extract all endpoints with a specific tag
Args:
spec: Parsed OpenAPI specification
tag: Tag to filter by
Returns:
List of endpoint details
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
if tag in operation.get('tags', []):
endpoints.append({
'path': path,
'method': method.lower(),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints
def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
"""
Get all tags used in the API spec
Args:
spec: Parsed OpenAPI specification
Returns:
Set of tag names
"""
tags = set()
for path_item in spec['paths'].values():
for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags'])
return tags
def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract all examples from component schemas
Args:
spec: Parsed OpenAPI specification
Returns:
Dict mapping schema names to examples
"""
examples = {}
if 'components' not in spec or 'schemas' not in spec['components']:
return examples
for name, schema in spec['components']['schemas'].items():
if 'example' in schema:
examples[name] = schema['example']
return examples

View File

@@ -0,0 +1,178 @@
"""
Utilities for API response validation against OpenAPI spec
"""
import yaml
import jsonschema
from typing import Any, Dict, List, Optional, Union
def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
"""
Load the OpenAPI specification from a YAML file
Args:
spec_path: Path to the OpenAPI specification file
Returns:
Dict containing the parsed OpenAPI spec
"""
with open(spec_path, 'r') as f:
return yaml.safe_load(f)
def get_endpoint_schema(
spec: Dict[str, Any],
path: str,
method: str,
status_code: str = '200'
) -> Optional[Dict[str, Any]]:
"""
Extract response schema for a specific endpoint from OpenAPI spec
Args:
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to get schema for
Returns:
Schema dict or None if not found
"""
method = method.lower()
# Handle path not found
if path not in spec['paths']:
return None
# Handle method not found
if method not in spec['paths'][path]:
return None
# Handle status code not found
responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses:
return None
# Handle no content defined
if 'content' not in responses[status_code]:
return None
# Get schema from first content type
content_types = responses[status_code]['content']
first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]:
return None
return content_types[first_content_type]['schema']
def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
"""
Resolve $ref references in a schema
Args:
schema: Schema that may contain references
spec: Full OpenAPI spec with component definitions
Returns:
Schema with references resolved
"""
if not isinstance(schema, dict):
return schema
result = {}
for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference
ref_path = value[2:].split('/')
ref_value = spec
for path_part in ref_path:
ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value)
elif isinstance(value, dict):
# Recursively resolve refs in nested dictionaries
result[key] = resolve_schema_refs(value, spec)
elif isinstance(value, list):
# Recursively resolve refs in list items
result[key] = [
resolve_schema_refs(item, spec) if isinstance(item, dict) else item
for item in value
]
else:
# Pass through other values
result[key] = value
return result
def validate_response(
response_data: Union[Dict[str, Any], List[Any]],
spec: Dict[str, Any],
path: str,
method: str,
status_code: str = '200'
) -> Dict[str, Any]:
"""
Validate a response against the OpenAPI schema
Args:
response_data: Response data to validate
spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to validate against
Returns:
Dict with validation result containing:
- valid: bool indicating if validation passed
- errors: List of validation errors if any
"""
schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None:
return {
'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
}
# Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec)
try:
jsonschema.validate(instance=response_data, schema=resolved_schema)
return {'valid': True, 'errors': []}
except jsonschema.exceptions.ValidationError as e:
return {'valid': False, 'errors': [str(e)]}
def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Extract all endpoints from an OpenAPI spec
Args:
spec: Parsed OpenAPI specification
Returns:
List of dicts with path, method, and tags for each endpoint
"""
endpoints = []
for path, path_item in spec['paths'].items():
for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue
endpoints.append({
'path': path,
'method': method.lower(),
'tags': operation.get('tags', []),
'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '')
})
return endpoints

View File

@@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError
EPSILON = 0.0001
@pytest.fixture
def sample_images():
"""3-frame 2x2 RGB video tensor"""
return torch.rand(3, 2, 2, 3)
@pytest.fixture
def sample_audio():
"""Stereo audio with 44.1kHz sample rate"""
return AudioInput(
{
"waveform": torch.rand(1, 2, 1000),
"sample_rate": 44100,
}
)
@pytest.fixture
def video_components(sample_images, sample_audio):
"""VideoComponents with images, audio, and metadata"""
return VideoComponents(
images=sample_images,
audio=sample_audio,
frame_rate=Fraction(30),
metadata={"test": "metadata"},
)
def create_test_video(width=4, height=4, frames=3, fps=30):
"""Helper to create a temporary video file"""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame = av.VideoFrame.from_ndarray(
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
format="rgb24",
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
# Flush
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def simple_video_file():
"""4x4 video with 3 frames at 30fps"""
file_path = create_test_video()
yield file_path
os.unlink(file_path)
def test_video_from_components_get_duration(video_components):
"""Duration calculated correctly from frame count and frame rate"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
expected_duration = 3.0 / 30.0
assert duration == pytest.approx(expected_duration)
def test_video_from_components_get_duration_different_frame_rates(sample_images):
"""Duration correct for different frame rates including fractional"""
# Test with 60 fps
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
video_60fps = VideoFromComponents(components_60fps)
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
# Test with fractional frame rate (23.976fps)
components_frac = VideoComponents(
images=sample_images, frame_rate=Fraction(24000, 1001)
)
video_frac = VideoFromComponents(components_frac)
expected_frac = 3.0 / (24000.0 / 1001.0)
assert video_frac.get_duration() == pytest.approx(expected_frac)
def test_video_from_components_get_duration_empty_video():
"""Duration is zero for empty video"""
empty_components = VideoComponents(
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
)
video = VideoFromComponents(empty_components)
assert video.get_duration() == 0.0
def test_video_from_components_get_dimensions(video_components):
"""Dimensions returned correctly from image tensor shape"""
video = VideoFromComponents(video_components)
width, height = video.get_dimensions()
assert width == 2
assert height == 2
def test_video_from_file_get_duration(simple_video_file):
"""Duration extracted from file metadata"""
video = VideoFromFile(simple_video_file)
duration = video.get_duration()
assert duration == pytest.approx(0.1, abs=0.01)
def test_video_from_file_get_dimensions(simple_video_file):
"""Dimensions read from stream without decoding frames"""
video = VideoFromFile(simple_video_file)
width, height = video.get_dimensions()
assert width == 4
assert height == 4
def test_video_from_file_bytesio_input():
"""VideoFromFile works with BytesIO input"""
buffer = io.BytesIO()
with av.open(buffer, mode="w", format="mp4") as container:
stream = container.add_stream("h264", rate=30)
stream.width = 2
stream.height = 2
stream.pix_fmt = "yuv420p"
frame = av.VideoFrame.from_ndarray(
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
buffer.seek(0)
video = VideoFromFile(buffer)
assert video.get_dimensions() == (2, 2)
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
def test_video_from_file_invalid_file_error():
"""InvalidDataError raised for non-video files"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
tmp.write(b"not a video file")
tmp.flush()
tmp_name = tmp.name
try:
with pytest.raises(InvalidDataError):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_video_from_file_audio_only_error():
"""ValueError raised for audio-only files"""
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
tmp_name = tmp.name
try:
with av.open(tmp_name, mode="w") as container:
stream = container.add_stream("aac", rate=44100)
stream.sample_rate = 44100
stream.format = "fltp"
audio_data = torch.zeros(1, 1024).numpy()
audio_frame = av.AudioFrame.from_ndarray(
audio_data, format="fltp", layout="mono"
)
audio_frame.sample_rate = 44100
audio_frame.pts = 0
packet = stream.encode(audio_frame)
container.mux(packet)
for packet in stream.encode(None):
container.mux(packet)
with pytest.raises(ValueError, match="No video stream found"):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_single_frame_video():
"""Single frame video has correct duration"""
components = VideoComponents(
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
)
video = VideoFromComponents(components)
assert video.get_duration() == 1.0
@pytest.mark.parametrize(
"frame_rate,expected_fps",
[
(Fraction(24000, 1001), 24000 / 1001),
(Fraction(30000, 1001), 30000 / 1001),
(Fraction(25, 1), 25.0),
(Fraction(50, 2), 25.0),
],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
"""Duration calculated correctly for various fractional frame rates"""
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
video = VideoFromComponents(components)
duration = video.get_duration()
expected_duration = 100.0 / expected_fps
assert duration == pytest.approx(expected_duration)
def test_duration_consistency(video_components):
"""get_duration() consistent with manual calculation from components"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
components = video.get_components()
manual_duration = float(components.images.shape[0] / components.frame_rate)
assert duration == pytest.approx(manual_duration)