Compare commits

59 Commits: desktop-re... → openapi-sp...
| Author | SHA1 | Date |
|---|---|---|
|  | d65ad9940b |  |
|  | e8a92e4c9b |  |
|  | fa9688b1fb |  |
|  | 87f9130778 |  |
|  | 7e84bf5373 |  |
|  | 4f3b50ba51 |  |
|  | e930a387d6 |  |
|  | d8e5662822 |  |
|  | 3d44a09812 |  |
|  | 62690eddec |  |
|  | 05eb10b43a |  |
|  | f5e4e976f4 |  |
|  | aee2908d03 |  |
|  | dc46db7aa4 |  |
|  | 7046983d95 |  |
|  | 1c2d45d2b5 |  |
|  | c820ef950d |  |
|  | 6a2e4bb9e0 |  |
|  | f1f9763b4c |  |
|  | 08368f8e00 |  |
|  | f3ff5c40db |  |
|  | 98ff01e148 |  |
|  | bab836d88d |  |
|  | 4a9014e201 |  |
|  | 8a7c894d54 |  |
|  | a814f2e8cc |  |
|  | 481732a0ed |  |
|  | 2156ce9453 |  |
|  | 4136502b7a |  |
|  | 9ad287ff20 |  |
|  | f5cacaeb14 |  |
|  | b7ed5f57bd |  |
|  | b4abca828e |  |
|  | 158419f3a0 |  |
|  | 640c47e7de |  |
|  | 31e9e36c94 |  |
|  | 577de83ca9 |  |
|  | 3535909eb8 |  |
|  | 235d3901fc |  |
|  | d42613686f |  |
|  | 1b3bf0a5da |  |
|  | ae60b150e5 |  |
|  | 42da274717 |  |
|  | 28f178a840 |  |
|  | 8ab15c863c |  |
|  | 924d771e18 |  |
|  | 02a1b01aad |  |
|  | a692c3cca4 |  |
|  | 5d3cc85e13 |  |
|  | c7c025b8d1 |  |
|  | fd08e39588 |  |
|  | 56b6ee6754 |  |
|  | cc33cd3422 |  |
|  | b9980592c4 |  |
|  | 16417b40d9 |  |
|  | 271c9c5b9e |  |
|  | a4e679765e |  |
|  | 0cf2e46b17 |  |
|  | 094e9ef126 |  |
.github/workflows/openapi-validation.yml (new file, vendored, +49 lines)

```yaml
name: Validate OpenAPI

on:
  push:
    branches: [ master ]
    paths:
      - 'openapi.yaml'
  pull_request:
    branches: [ master ]
    paths:
      - 'openapi.yaml'

jobs:
  openapi-check:
    runs-on: ubuntu-latest

    # Service containers to run with `runner-job`
    services:
      # Label used to access the service container
      swagger-editor:
        # Docker Hub image
        image: swaggerapi/swagger-editor
        ports:
          # Maps port 8080 on service container to the host 80
          - 80:8080

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Validate OpenAPI definition
        uses: swaggerexpert/swagger-editor-validate@v1
        with:
          definition-file: openapi.yaml
          swagger-editor-url: http://localhost/
          default-timeout: 20000

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Install test dependencies
        run: |
          pip install -r tests-api/requirements.txt

      - name: Run OpenAPI spec validation tests
        run: |
          pytest tests-api/test_spec_validation.py -v
```
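The last workflow step runs `pytest tests-api/test_spec_validation.py`; that test file is not part of this diff. For orientation, a spec-validation test of this shape might look like the sketch below. The `validate` call comes from the `openapi-spec-validator` package; the file layout and test names are assumptions, not the repository's actual tests.

```python
# Hypothetical sketch of a spec-validation test; the real
# tests-api/test_spec_validation.py is not shown in this diff.
from pathlib import Path

import yaml
from openapi_spec_validator import validate  # pip install openapi-spec-validator

# Assumed layout: tests-api/ sits next to openapi.yaml at the repo root.
SPEC_PATH = Path(__file__).resolve().parent.parent / "openapi.yaml"


def load_spec():
    with open(SPEC_PATH) as f:
        return yaml.safe_load(f)


def test_spec_parses_as_yaml():
    # Malformed YAML should fail here, before any schema checks.
    assert isinstance(load_spec(), dict)


def test_spec_is_valid_openapi():
    # Raises a validation error on any OpenAPI schema violation.
    validate(load_spec())
```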
.gitignore (vendored, 1 line changed)

```diff
@@ -21,6 +21,5 @@ venv/
 *.log
 web_custom_versions/
 .DS_Store
-openapi.yaml
 filtered-openapi.yaml
 uv.lock
```

The removed line un-ignores `openapi.yaml`: the new workflow triggers on pushes that touch `openapi.yaml`, so the spec has to be tracked in the repository.
README.md (13 lines changed)

```diff
@@ -69,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/)
    - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
    - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
    - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
-- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+- Audio Models
+   - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+   - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - 3D Models
    - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
```
```diff
@@ -108,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
 
 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
    - Builds a new release using the latest stable core version
-   - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
 
 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
    - Weekly frontend updates are merged into the core repository
```
````diff
@@ -196,11 +197,11 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
 
-This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
 
 ### Intel GPUs (Windows and Linux)
````
````diff
@@ -300,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
 
 ### AMD ROCm Tips
 
-You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
 
 ```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
 
````
comfy/cli_args.py

```diff
@@ -142,12 +142,15 @@ class PerformanceFeature(enum.Enum):
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
 
+parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
 
 parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
 parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
 
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
 
```
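The `--fast` flag pairs `nargs="*"` with an enum `type`, so argparse converts each token via `PerformanceFeature(value)` and a bare `--fast` produces an empty list the caller can treat as "enable everything". A self-contained sketch of that pattern; the enum member names below are written from the help text, not copied from cli_args.py:

```python
# Sketch of the --fast pattern: an enum-typed flag where nargs="*" makes the
# bare flag mean "enable all features". Member values follow the help text.
import argparse
import enum


class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMult = "fp8_matrix_mult"
    CublasOps = "cublas_ops"


parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature)

args = parser.parse_args(["--fast", "fp16_accumulation", "cublas_ops"])
print(args.fast)  # [PerformanceFeature.Fp16Accumulation, PerformanceFeature.CublasOps]

args = parser.parse_args(["--fast"])
print(args.fast)  # [] -- callers interpret the empty list as "everything"
```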
comfy/comfy_types/node_typing.py

```diff
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
     API_NODE: Optional[bool]
-    """Flags a node as an API node."""
+    """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
 
     @classmethod
     @abstractmethod
```
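As a hypothetical illustration of the class-level flags the new docstring documents (this class is not from the diff):

```python
# Hypothetical node showing how the documented flags are set on a node class.
from typing import Optional


class ExampleAPINode:
    DEPRECATED: bool = False
    API_NODE: Optional[bool] = True  # marks the node as an API node


print(ExampleAPINode.API_NODE)  # True
```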
comfy/k_diffusion/sampling.py

```diff
@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
     phi1_fn = lambda t: torch.expm1(t) / t
     phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
 
+    old_sigma_down = None
     old_denoised = None
     uncond_denoised = None
     def post_cfg_function(args):
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             x = x + d * dt
         else:
             # Second order multistep method in https://arxiv.org/pdf/2308.02157
-            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
+            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
             h = t_next - t
-            c2 = (t_prev - t) / h
+            c2 = (t_prev - t_old) / h
 
             phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
             b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             old_denoised = uncond_denoised
         else:
             old_denoised = denoised
+        old_sigma_down = sigma_down
     return x
 
 @torch.no_grad()
```
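The substance of this fix: with ancestral noise (eta > 0) step i-1 actually lands on `old_sigma_down`, not on `sigmas[i]`, so the second-order coefficient `c2` has to be measured from that point. A toy numeric sketch with invented sigma values (not the sampler itself):

```python
# Toy illustration of the c2 fix in res_multistep, with made-up sigmas.
import math

t_fn = lambda sigma: -math.log(sigma)

sigmas = [1.0, 0.7, 0.5]   # hypothetical noise schedule around step i = 1
old_sigma_down = 0.65      # where step i-1 really ended (eta > 0 shifts it)
sigma_down = 0.45          # where step i will end

t, t_old = t_fn(sigmas[1]), t_fn(old_sigma_down)
t_next, t_prev = t_fn(sigma_down), t_fn(sigmas[0])
h = t_next - t

c2_before_fix = (t_prev - t) / h      # ignores the ancestral offset
c2_after_fix = (t_prev - t_old) / h   # measures from the true previous endpoint
print(c2_before_fix, c2_after_fix)    # the two differ whenever eta > 0
```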
comfy/latent_formats.py

```diff
@@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
     scale_factor = 1.0188137142395404
+
+class ACEAudio(LatentFormat):
+    latent_channels = 8
+    latent_dimensions = 2
```
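For context, `LatentFormat` subclasses are small declarative records. A minimal sketch of the contract they satisfy; the real base class in comfy/latent_formats.py has more fields, and the scale handling here is simplified:

```python
# Minimal sketch of the LatentFormat contract, simplified from
# comfy/latent_formats.py.
import torch


class LatentFormat:
    scale_factor = 1.0
    latent_channels = 4
    latent_dimensions = 2  # non-batch, non-channel axes of the latent

    def process_in(self, latent):
        return latent * self.scale_factor

    def process_out(self, latent):
        return latent / self.scale_factor


class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2


fmt = ACEAudio()
x = torch.randn(1, fmt.latent_channels, 16, 256)  # e.g. (batch, C, height, frames)
assert fmt.process_out(fmt.process_in(x)).allclose(x)
```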
comfy/ldm/ace/attention.py (new file, +761 lines)

```python
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional

import torch
import torch.nn.functional as F
from torch import nn

import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        processor=None,
        out_dim: int = None,
        out_context_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        self.group_norm = None
        self.spatial_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
            self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
        else:
            self.to_k = None
            self.to_v = None

        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            if self.context_pre_only is not None:
                self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
            self.to_out.append(nn.Dropout(dropout))
        else:
            self.to_out = None

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
        else:
            self.to_add_out = None

        self.norm_added_q = None
        self.norm_added_k = None
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

class CustomLiteLAProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""

    def __init__(self):
        self.kernel_func = nn.ReLU(inplace=False)
        self.eps = 1e-15
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        hidden_states_len = hidden_states.shape[1]

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            # attention
            if not attn.is_cross_attention:
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D] input
        # query is currently [B, H, D, S]; convert to [B, H, S, D] before applying RoPE
        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

        # Apply query and key normalization if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # query is now [B, H, S, D]; restore it to [B, H, D, S]
        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S, 1]
            query = query * attention_mask.permute(0, 1, 3, 2)  # [B, H, S, D] * [B, 1, S, 1]
            if not attn.is_cross_attention:
                key = key * attention_mask  # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
                value = value * attention_mask.permute(0, 1, 3, 2)  # value is [B, h, D, S], so the mask must be adjusted to match the S dimension

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S_enc, 1]
            # here key: [B, h, S_enc, D], value: [B, h, D, S_enc]
            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)  # [B, h, D, S_enc] * [B, 1, 1, S_enc]

        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()

        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)

        hidden_states = torch.matmul(vk, query)

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the attention outputs.
        if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states

class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:

        residual = hidden_states
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # cross attention: combine attention_mask and encoder_attention_mask
            combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = optimized_attention(
            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
        ).to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore
    """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore
    """Return tuple with min_len by repeating element at idx_repeat."""
    # convert to list first
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift


def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
        return kernel_size // 2


class ConvLayer(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=None,
        act=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        if padding is None:
            padding = get_same_padding(kernel_size)
            padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = operations.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
            device=device,
            dtype=dtype
        )
        if norm is not None:
            self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
        else:
            self.norm = None
        if act is not None:
            self.act = nn.SiLU(inplace=True)
        else:
            self.act = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x


class GLUMBConv(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
        dtype=None, device=None, operations=None
    ):
        out_feature = out_feature or in_features
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        x = x.transpose(1, 2)

        return x

class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
    """
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
            processor=CustomLiteLAProcessor2_0(),
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only

        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                query_dim=dim,
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                context_pre_only=context_pre_only,
                bias=True,
                qk_norm=qk_norm,
                processor=CustomerAttnProcessor2_0(),
                dtype=dtype,
                device=device,
                operations=operations,
            )

        self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.use_adaln_single = use_adaln_single
        if use_adaln_single:
            self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        temb: torch.FloatTensor = None,
    ):

        N = hidden_states.shape[0]

        # step 1: AdaLN single
        if self.use_adaln_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
            ).chunk(6, dim=1)

        norm_hidden_states = self.norm1(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        # step 2: attention
        if not self.add_cross_attention:
            attn_output, encoder_hidden_states = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
        else:
            attn_output, _ = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )

        if self.use_adaln_single:
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states

        if self.add_cross_attention:
            attn_output = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            hidden_states = attn_output + hidden_states

        # step 3: add norm
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        # step 4: feed forward
        ff_output = self.ff(norm_hidden_states)
        if self.use_adaln_single:
            ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states
```
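The core of `CustomLiteLAProcessor2_0` is ReLU-kernel linear attention: padding `value` with a row of ones makes the extra output channel accumulate the normalizer, so one pair of matmuls yields both numerator and denominator. A standalone check of that identity (shapes follow the processor above; this is an illustration, not ComfyUI code):

```python
# Standalone check of the pad-with-ones linear-attention trick used above.
import torch
import torch.nn.functional as F

B, H, D, S = 1, 2, 8, 16
phi = torch.relu                            # the processor's kernel_func

q = phi(torch.randn(B, H, D, S))            # [B, H, D, S]
k = phi(torch.randn(B, H, S, D))            # [B, H, S, D]
v = torch.randn(B, H, D, S)                 # [B, H, D, S]

v_pad = F.pad(v, (0, 0, 0, 1), value=1.0)   # [B, H, D+1, S]: extra row of ones
out = (v_pad @ k) @ q                       # [B, H, D+1, S]
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)   # numerator / normalizer

# Explicit O(S^2) reference:
# out[d, i] = sum_j v[d, j] * (k_j . q_i) / sum_j (k_j . q_i)
scores = torch.einsum("bhsd,bhdt->bhst", k, q)
ref = torch.einsum("bhds,bhst->bhdt", v, scores) / (scores.sum(dim=2, keepdim=True) + 1e-15)
print(torch.allclose(out, ref, atol=1e-5))  # True
```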
comfy/ldm/ace/lyric_encoder.py (new file, +1067 lines; file diff suppressed because it is too large)
comfy/ldm/ace/model.py (new file, +385 lines)

```python
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Union

import torch
from torch import nn

import comfy.model_management

from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder


def cross_norm(hidden_states, controlnet_input):
    # input N x T x c
    mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
    mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
    controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
    return controlnet_input


# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
        self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # 4 unpatchify
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output

class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
            operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
            operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
        )
        self.patch_size = patch_size
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width

    def forward(self, latent):
        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent

class ACEStepTransformer2DModel(nn.Module):
    # _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        audio_model=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.dtype = dtype
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
            dtype=dtype,
            device=device,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        self.num_layers = num_layers
        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for i in range(self.num_layers)
            ]
        )

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))

        # speaker
        self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # genre
        self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # lyric
        self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
        self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)

        projector_dim = 2 * self.inner_dim

        self.projectors = nn.ModuleList([
            nn.Sequential(
                operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        out_dtype=None,
    ):
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        lyrics_strength=1.0,
    ):

        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            out_dtype=encoder_text_hidden_states.dtype,
        )

        encoder_lyric_hidden_states *= lyrics_strength

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)

        encoder_hidden_mask = None
        if text_attention_mask is not None:
            speaker_mask = torch.ones(bs, 1, device=device)
            encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)

        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
    ):
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        # inner_hidden_states = []

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_hidden_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                temb=temb,
            )

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        return output

    def forward(
        self,
        x,
        timestep,
        attention_mask=None,
        context: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        lyrics_strength=1.0,
        **kwargs
    ):
        hidden_states = x
        encoder_text_hidden_states = context
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            lyrics_strength=lyrics_strength,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
        )

        return output
```
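To make `cross_norm` concrete: it re-standardizes the controlnet stream to the per-sample mean and std of the hidden states before the two are summed in `decode`. A quick standalone check using the same formula (duplicated here so the snippet runs on its own):

```python
# Standalone check of cross_norm: the controlnet stream is rescaled to the
# hidden states' per-sample statistics. Formula copied from the file above.
import torch


def cross_norm(hidden_states, controlnet_input):
    mean_h, std_h = hidden_states.mean(dim=(1, 2), keepdim=True), hidden_states.std(dim=(1, 2), keepdim=True)
    mean_c, std_c = controlnet_input.mean(dim=(1, 2), keepdim=True), controlnet_input.std(dim=(1, 2), keepdim=True)
    return (controlnet_input - mean_c) * (std_h / (std_c + 1e-12)) + mean_h


h = torch.randn(2, 100, 64) * 3.0 + 1.0    # hidden states, N x T x C
c = torch.randn(2, 100, 64) * 10.0 - 5.0   # controlnet input on a different scale
out = cross_norm(h, c)
print(out.mean(dim=(1, 2)), h.mean(dim=(1, 2)))  # per-sample means now agree
print(out.std(dim=(1, 2)), h.std(dim=(1, 2)))    # per-sample stds now agree
```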
comfy/ldm/ace/vae/autoencoder_dc.py (new file, +644 lines)

```python
# Rewritten from diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union

import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init


class RMSNorm(ops.RMSNorm):
    def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
        if elementwise_affine:
            self.bias = nn.Parameter(torch.empty(dim)) if bias else None

    def forward(self, x):
        x = super().forward(x)
        if self.elementwise_affine:
            if self.bias is not None:
                x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
        return x


def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
    if norm_type == "batch_norm":
        return nn.BatchNorm2d(num_features)
    elif norm_type == "group_norm":
        return ops.GroupNorm(num_groups, num_features)
    elif norm_type == "layer_norm":
        return ops.LayerNorm(num_features)
    elif norm_type == "rms_norm":
        return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
    else:
        raise ValueError(f"Unknown normalization type: {norm_type}")


def get_activation(activation_type):
    if activation_type == "relu":
        return nn.ReLU()
    elif activation_type == "relu6":
        return nn.ReLU6()
    elif activation_type == "silu":
        return nn.SiLU()
    elif activation_type == "leaky_relu":
        return nn.LeakyReLU(0.2)
    else:
        raise ValueError(f"Unknown activation type: {activation_type}")

class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = ops.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
)
|
||||
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: int = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: tuple = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult)
|
||||
if num_attention_heads is None
|
||||
else num_attention_heads
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, out_channels)
|
||||
|
||||
def apply_linear_attention(self, query, key, value):
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query, key, value):
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states):
|
||||
height, width = hidden_states.shape[-2:]
|
||||
if height * width > self.attention_head_dim:
|
||||
use_linear_attention = True
|
||||
else:
|
||||
use_linear_attention = False
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, _, height, width = list(hidden_states.size())
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
hidden_states = hidden_states.movedim(1, -1)
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
hidden_states = torch.cat([query, key, value], dim=3)
|
||||
hidden_states = hidden_states.movedim(-1, 1)
|
||||
|
||||
multi_scale_qkv = [hidden_states]
|
||||
for block in self.to_qkv_multiscale:
|
||||
multi_scale_qkv.append(block(hidden_states))
|
||||
|
||||
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||
|
||||
if use_linear_attention:
|
||||
# for linear attention upcast hidden_states to float32
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
|
||||
|
||||
query, key, value = hidden_states.chunk(3, dim=2)
|
||||
query = self.nonlinearity(query)
|
||||
key = self.nonlinearity(key)
|
||||
|
||||
if use_linear_attention:
|
||||
hidden_states = self.apply_linear_attention(query, key, value)
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
else:
|
||||
hidden_states = self.apply_quadratic_attention(query, key, value)
|
||||
|
||||
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
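

# Editor's note (illustration, not part of the diff): apply_linear_attention pads
# the value tensor with a row of ones so a single pair of matmuls produces both
# the attention numerator and its normalizer sum_j (k_j . q_n). A quick check of
# that identity against the explicit normalized form, in a token-last layout:
def _linear_attention_ones_row_demo():
    q = torch.relu(torch.randn(1, 4, 6))  # query, (B, d, N)
    k = torch.relu(torch.randn(1, 4, 6))  # key,   (B, d, N)
    v = torch.randn(1, 4, 6)              # value, (B, d, N)

    v_pad = F.pad(v, (0, 0, 0, 1), value=1.0)   # (B, d+1, N); last row is all ones
    out = (v_pad @ k.transpose(-1, -2)) @ q     # (B, d+1, N)
    out = out[:, :-1] / (out[:, -1:] + 1e-15)   # split numerator / normalizer

    w = k.transpose(-1, -2) @ q                 # (B, N, N); w[:, j, n] = k_j . q_n
    ref = (v @ w) / (w.sum(dim=1, keepdim=True) + 1e-15)
    assert torch.allclose(out, ref, atol=1e-5)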


class EfficientViTBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        mult: float = 1.0,
        attention_head_dim: int = 32,
        qkv_multiscales: tuple = (5,),
        norm_type: str = "batch_norm",
    ) -> None:
        super().__init__()

        self.attn = SanaMultiscaleLinearAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            mult=mult,
            attention_head_dim=attention_head_dim,
            norm_type=norm_type,
            kernel_sizes=qkv_multiscales,
            residual_connection=True,
        )

        self.conv_out = GLUMBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            norm_type="rms_norm",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.conv_out(x)
        return x


class GLUMBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: float = 4,
        norm_type: Optional[str] = None,
        residual_connection: bool = True,
    ) -> None:
        super().__init__()

        hidden_channels = int(expand_ratio * in_channels)
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        self.nonlinearity = nn.SiLU()
        self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
        self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
        self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)

        self.norm = None
        if norm_type == "rms_norm":
            self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.residual_connection:
            residual = hidden_states

        hidden_states = self.conv_inverted(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv_depth(hidden_states)
        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
        hidden_states = hidden_states * self.nonlinearity(gate)

        hidden_states = self.conv_point(hidden_states)

        if self.norm_type == "rms_norm":
            # move channels to the last dimension so RMSNorm is applied across the channel dimension
            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states


def get_block(
    block_type: str,
    in_channels: int,
    out_channels: int,
    attention_head_dim: int,
    norm_type: str,
    act_fn: str,
    qkv_mutliscales: tuple = (),
):
    if block_type == "ResBlock":
        block = ResBlock(in_channels, out_channels, norm_type, act_fn)
    elif block_type == "EfficientViTBlock":
        block = EfficientViTBlock(
            in_channels,
            attention_head_dim=attention_head_dim,
            norm_type=norm_type,
            qkv_multiscales=qkv_mutliscales
        )
    else:
        raise ValueError(f"Block with {block_type=} is not supported.")

    return block


class DCDownBlock2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
        super().__init__()

        self.downsample = downsample
        self.factor = 2
        self.stride = 1 if downsample else 2
        self.group_size = in_channels * self.factor**2 // out_channels
        self.shortcut = shortcut

        out_ratio = self.factor**2
        if downsample:
            assert out_channels % out_ratio == 0
            out_channels = out_channels // out_ratio

        self.conv = ops.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.conv(hidden_states)
        if self.downsample:
            x = F.pixel_unshuffle(x, self.factor)

        if self.shortcut:
            y = F.pixel_unshuffle(hidden_states, self.factor)
            y = y.unflatten(1, (-1, self.group_size))
            y = y.mean(dim=2)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states


class DCUpBlock2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        interpolate: bool = False,
        shortcut: bool = True,
        interpolation_mode: str = "nearest",
    ) -> None:
        super().__init__()

        self.interpolate = interpolate
        self.interpolation_mode = interpolation_mode
        self.shortcut = shortcut
        self.factor = 2
        self.repeats = out_channels * self.factor**2 // in_channels

        out_ratio = self.factor**2
        if not interpolate:
            out_channels = out_channels * out_ratio

        self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.interpolate:
            x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
            x = self.conv(x)
        else:
            x = self.conv(hidden_states)
            x = F.pixel_shuffle(x, self.factor)

        if self.shortcut:
            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
            y = F.pixel_shuffle(y, self.factor)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states
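

# Editor's note (illustration, not part of the diff): the shortcut branches above
# trade space for channels. Down: pixel_unshuffle gives (C * 4, H/2, W/2) and channel
# groups are averaged back down to the conv's output width. Up: channels are
# repeated until pixel_shuffle can spend them on a 2x spatial upsample.
def _dc_shortcut_shape_demo():
    x = torch.randn(1, 128, 32, 32)
    y = F.pixel_unshuffle(x, 2)                       # (1, 512, 16, 16)
    group_size = x.shape[1] * 4 // 256                # 512 // 256 = 2
    y = y.unflatten(1, (-1, group_size)).mean(dim=2)  # (1, 256, 16, 16)
    assert y.shape == (1, 256, 16, 16)

    z = torch.randn(1, 1024, 8, 8)
    repeats = 512 * 4 // z.shape[1]                   # 2048 // 1024 = 2
    w = F.pixel_shuffle(z.repeat_interleave(repeats, dim=1), 2)
    assert w.shape == (1, 512, 16, 16)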


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, tuple] = "ResBlock",
        block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
        downsample_block_type: str = "pixel_unshuffle",
        out_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)

        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks

        if layers_per_block[0] > 0:
            self.conv_in = ops.Conv2d(
                in_channels,
                block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        else:
            self.conv_in = DCDownBlock2d(
                in_channels=in_channels,
                out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
                downsample=downsample_block_type == "pixel_unshuffle",
                shortcut=False,
            )

        down_blocks = []
        for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
            down_block_list = []

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type="rms_norm",
                    act_fn="silu",
                    qkv_mutliscales=qkv_multiscales[i],
                )
                down_block_list.append(block)

            if i < num_blocks - 1 and num_layers > 0:
                downsample_block = DCDownBlock2d(
                    in_channels=out_channel,
                    out_channels=block_out_channels[i + 1],
                    downsample=downsample_block_type == "pixel_unshuffle",
                    shortcut=True,
                )
                down_block_list.append(downsample_block)

            down_blocks.append(nn.Sequential(*down_block_list))

        self.down_blocks = nn.ModuleList(down_blocks)

        self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)

        self.out_shortcut = out_shortcut
        if out_shortcut:
            self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        if self.out_shortcut:
            x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
            x = x.mean(dim=2)
            hidden_states = self.conv_out(hidden_states) + x
        else:
            hidden_states = self.conv_out(hidden_states)

        return hidden_states


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, tuple] = "ResBlock",
        block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
        norm_type: Union[str, tuple] = "rms_norm",
        act_fn: Union[str, tuple] = "silu",
        upsample_block_type: str = "pixel_shuffle",
        in_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)

        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks
        if isinstance(norm_type, str):
            norm_type = (norm_type,) * num_blocks
        if isinstance(act_fn, str):
            act_fn = (act_fn,) * num_blocks

        self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)

        self.in_shortcut = in_shortcut
        if in_shortcut:
            self.in_shortcut_repeats = block_out_channels[-1] // latent_channels

        up_blocks = []
        for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
            up_block_list = []

            if i < num_blocks - 1 and num_layers > 0:
                upsample_block = DCUpBlock2d(
                    block_out_channels[i + 1],
                    out_channel,
                    interpolate=upsample_block_type == "interpolate",
                    shortcut=True,
                )
                up_block_list.append(upsample_block)

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type=norm_type[i],
                    act_fn=act_fn[i],
                    qkv_mutliscales=qkv_multiscales[i],
                )
                up_block_list.append(block)

            up_blocks.insert(0, nn.Sequential(*up_block_list))

        self.up_blocks = nn.ModuleList(up_blocks)

        channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]

        self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
        self.conv_act = nn.ReLU()
        self.conv_out = None

        if layers_per_block[0] > 0:
            self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
        else:
            self.conv_out = DCUpBlock2d(
                channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.in_shortcut:
            x = hidden_states.repeat_interleave(
                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
            )
            hidden_states = self.conv_in(hidden_states) + x
        else:
            hidden_states = self.conv_in(hidden_states)

        for up_block in reversed(self.up_blocks):
            hidden_states = up_block(hidden_states)

        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class AutoencoderDC(nn.Module):
    def __init__(
        self,
        in_channels: int = 2,
        latent_channels: int = 8,
        attention_head_dim: int = 32,
        encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
        decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
        encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
        encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
        decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
        encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
        decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
        upsample_block_type: str = "interpolate",
        downsample_block_type: str = "Conv",
        decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
        decoder_act_fns: Union[str, Tuple[str]] = "silu",
        scaling_factor: float = 0.41407,
    ) -> None:
        super().__init__()

        self.encoder = Encoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            attention_head_dim=attention_head_dim,
            block_type=encoder_block_types,
            block_out_channels=encoder_block_out_channels,
            layers_per_block=encoder_layers_per_block,
            qkv_multiscales=encoder_qkv_multiscales,
            downsample_block_type=downsample_block_type,
        )

        self.decoder = Decoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            attention_head_dim=attention_head_dim,
            block_type=decoder_block_types,
            block_out_channels=decoder_block_out_channels,
            layers_per_block=decoder_layers_per_block,
            qkv_multiscales=decoder_qkv_multiscales,
            norm_type=decoder_norm_types,
            act_fn=decoder_act_fns,
            upsample_block_type=upsample_block_type,
        )

        self.scaling_factor = scaling_factor
        self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode to latents, multiplied by scaling_factor."""
        encoded = self.encoder(x)
        return encoded * self.scaling_factor

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # scale the latents back before decoding
        z = z / self.scaling_factor
        decoded = self.decoder(z)
        return decoded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encode(x)
        return self.decode(z)
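

# Editor's note (illustration, not part of the diff): a roundtrip under the default
# config above (2-channel mel input, 8 latent channels, 8x spatial compression),
# assuming the defaults form a self-consistent model. Spatial dims must be
# divisible by spatial_compression_ratio for the shapes to close.
def _autoencoder_dc_roundtrip_demo():
    model = AutoencoderDC().eval()
    mel = torch.randn(1, 2, 16, 256)  # (B, C, mel bins, frames)
    with torch.no_grad():
        z = model.encode(mel)         # (1, 8, 2, 32), already scaled by scaling_factor
        rec = model.decode(z)         # back to (1, 2, 16, 256)
    assert z.shape == (1, 8, 2, 32)
    assert rec.shape == mel.shape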
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
@@ -0,0 +1,109 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import logging
try:
    import torchaudio
except Exception:
    logging.warning("torchaudio missing, ACE model will be broken")

import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1


class MusicDCAE(torch.nn.Module):
    def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
        super(MusicDCAE, self).__init__()

        self.dcae = AutoencoderDC(**dcae_config)
        self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)

        if source_sample_rate is None:
            self.source_sample_rate = 48000
        else:
            self.source_sample_rate = source_sample_rate

        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        audio, sr = torchaudio.load(audio_path)
        return audio, sr

    def forward_mel(self, audios):
        mels = []
        for i in range(len(audios)):
            image = self.vocoder.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
        audio_lengths = audio_lengths.to(audios.device)

        if sr is None:
            sr = self.source_sample_rate

        if sr != 44100:
            audios = torchaudio.functional.resample(audios, sr, 44100)

        max_audio_len = audios.shape[-1]
        if max_audio_len % (8 * 512) != 0:
            audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))

        mels = self.forward_mel(audios)
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        latents = []
        for mel in mels:
            latent = self.dcae.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        latents = latents / self.scale_factor + self.shift_factor

        pred_wavs = []

        for latent in latents:
            mels = self.dcae.decoder(latent.unsqueeze(0))
            mels = mels * 0.5 + 0.5
            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
            wav = self.vocoder.decode(mels[0]).squeeze(1)

            if sr is not None:
                wav = torchaudio.functional.resample(wav, 44100, sr)
            else:
                sr = 44100
            pred_wavs.append(wav)

        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return torch.stack(pred_wavs)

    def forward(self, audios, audio_lengths=None, sr=None):
        # encode() returns only the latents and decode() returns the stacked
        # waveforms, so forward returns the (waveforms, latents) pair.
        latents = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return pred_wavs, latents
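

# Editor's note (illustration, not part of the diff): how the fixed sizes above
# relate. One latent frame covers time_dimention_multiple = 8 mel frames, one mel
# frame covers 512 samples at 44.1 kHz, so a 1024-frame mel chunk is ~11.9 s of
# audio and 128 latent frames; audio_chunk_size is that span in 48 kHz samples.
def _chunk_size_demo():
    samples_44k = 1024 * 512                      # mel_chunk_size * hop_length = 524288
    seconds = samples_44k / 44100                 # ~11.89 s per chunk
    audio_chunk_48k = int(round(samples_44k / 44100 * 48000))  # 570654, the audio_chunk_size above
    latent_frames = 1024 // 8                     # 128, the latent_chunk_size above
    return seconds, audio_chunk_48k, latent_frames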
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
@@ -0,0 +1,113 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import torch
import torch.nn as nn
from torch import Tensor
import logging
try:
    from torchaudio.transforms import MelScale
except Exception:
    logging.warning("torchaudio missing, ACE model will be broken")

import comfy.model_management


class LinearSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        if y.ndim == 3:
            y = y.squeeze(1)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)
        dtype = y.dtype
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        spec = spec.to(dtype)
        return spec


class LogMelSpectrogram(nn.Module):
    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        x = self.mel_scale(linear)
        x = self.compress(x)
        if return_linear:
            return x, self.compress(linear)

        return x
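

# Editor's note (illustration, not part of the diff): with center=False plus the
# (win - hop) // 2 reflect padding used above, the STFT produces exactly one
# frame per hop when the input length is a multiple of hop_length.
def _framing_demo():
    n_fft = win = 2048
    hop = 512
    audio = torch.randn(1, hop * 100)
    padded = torch.nn.functional.pad(
        audio.unsqueeze(1), ((win - hop) // 2, (win - hop + 1) // 2), mode="reflect"
    ).squeeze(1)
    spec = torch.stft(padded, n_fft, hop_length=hop, win_length=win,
                      window=torch.hann_window(win), center=False, return_complex=True)
    assert spec.shape[-1] == 100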
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
@@ -0,0 +1,538 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
import torch
from torch import nn

from functools import partial
from math import prod
from typing import Callable, Tuple, List

import numpy as np
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm

from .music_log_mel import LogMelSpectrogram

import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """  # noqa: E501

    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with tensors of any dimensionality, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"


class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
            return x


class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        self.dwconv = ops.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = ops.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(torch.empty((dim)), requires_grad=False)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        input = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = input + x

        return x


class ParallelConvNeXtBlock(nn.Module):
    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
                for kernel_size in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.stack(
            [block(x, apply_residual=False) for block in self.blocks] + [x],
            dim=1,
        ).sum(dim=1)


class ConvNeXtEncoder(nn.Module):
    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.channel_layers = nn.ModuleList()
        stem = nn.Sequential(
            ops.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)


def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        self.convs1 = nn.ModuleList(
            [
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                torch.nn.utils.parametrizations.weight_norm(
                    ops.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)


class HiFiGANGenerator(nn.Module):
    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
            ops.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                torch.nn.utils.parametrizations.weight_norm(
                    ops.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(
                    ops.Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        self.conv_post = torch.nn.utils.parametrizations.weight_norm(
            ops.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )

    def forward(self, x, template=None):
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class ADaMoSHiFiGANV1(nn.Module):
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        self.backbone = ConvNeXtEncoder(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )

        self.head = HiFiGANGenerator(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.sampling_rate = sampling_rate
        self.mel_transform = LogMelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        y = self.backbone(mel)
        y = self.head(y)
        return y

    @torch.no_grad()
    def encode(self, x):
        return self.mel_transform(x)

    def forward(self, mel):
        y = self.backbone(mel)
        y = self.head(y)
        return y
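

# Editor's note (illustration, not part of the diff): HiFiGANGenerator asserts
# that the upsample rates multiply out to hop_length, so each mel frame maps to
# exactly hop_length waveform samples. With the ADaMoSHiFiGANV1 defaults:
def _upsample_rate_demo():
    upsample_rates = (4, 4, 2, 2, 2, 2, 2)
    assert prod(upsample_rates) == 512  # == hop_length: 4*4*2*2*2*2*2
    return 1024 * 512                   # samples produced by a 1024-frame mel chunk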
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
         return x
 
 def WNConv1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
 
 def WNConvTranspose1d(*args, **kwargs):
-    try:
-        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
-    except:
-        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
+    return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
 
 def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
     if activation == "elu":
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
         y: Tensor,
         guidance: Tensor = None,
         guiding_frame_index=None,
+        ref_latent=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if ref_latent is not None:
+            ref_latent_ids = self.img_ids(ref_latent)
+            ref_latent = self.img_in(ref_latent)
+            img = torch.cat([ref_latent, img], dim=-2)
+            ref_latent_ids[..., 0] = -1
+            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
             vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
             img[:, : img_len] += add
 
         img = img[:, : img_len]
+        if ref_latent is not None:
+            img = img[:, ref_latent.shape[1]:]
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
+    def img_ids(self, x):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        img_ids = self.img_ids(x)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
         return c_skip, c
 
 
+class WanCamAdapter(nn.Module):
+    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
+        super(WanCamAdapter, self).__init__()
+
+        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
+        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
+
+        # Convolution: reduce spatial dimensions by a factor
+        # of 2 (without overlap)
+        self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+        # Residual blocks for feature extraction
+        self.residual_blocks = nn.Sequential(
+            *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
+        )
+
+    def forward(self, x):
+        # Reshape to merge the frame dimension into batch
+        bs, c, f, h, w = x.size()
+        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
+
+        # Pixel Unshuffle operation
+        x_unshuffled = self.pixel_unshuffle(x)
+
+        # Convolution operation
+        x_conv = self.conv(x_unshuffled)
+
+        # Feature extraction with residual blocks
+        out = self.residual_blocks(x_conv)
+
+        # Reshape to restore original bf dimension
+        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
+
+        # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
+        out = out.permute(0, 2, 1, 3, 4)
+
+        return out
+
+
+class WanCamResidualBlock(nn.Module):
+    def __init__(self, dim, operation_settings={}):
+        super(WanCamResidualBlock, self).__init__()
+        self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x):
+        residual = x
+        out = self.relu(self.conv1(x))
+        out = self.conv2(out)
+        out += residual
+        return out
+
+
 class Head(nn.Module):
 
     def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
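
# Editor's note (illustration, not part of the diff): WanCamAdapter's spatial
# bookkeeping. Pixel-unshuffling the per-frame camera maps by 8 multiplies the
# channels by 64, and the 2x2 strided conv then matches the transformer's spatial
# patching, so each latent token receives one camera embedding. The in_dim=24
# input is an assumption taken from in_dim_control_adapter below.
import torch
import torch.nn as nn

def _wan_cam_adapter_shape_demo():
    bs, c, f, h, w = 1, 24, 4, 64, 64
    x = torch.randn(bs, c, f, h, w).permute(0, 2, 1, 3, 4).reshape(bs * f, c, h, w)
    x = nn.PixelUnshuffle(8)(x)                              # (4, 1536, 8, 8)
    x = nn.Conv2d(24 * 64, 2048, kernel_size=(2, 2), stride=(2, 2))(x)
    assert x.shape == (bs * f, 2048, 4, 4)                   # one token per 16x16 pixels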
@@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
     return x
+
+class CameraWanModel(WanModel):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and image-to-video.
+    """
+
+    def __init__(self,
+                 model_type='camera',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 in_dim_control_adapter=24,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        clip_fea=None,
+        freqs=None,
+        camera_conditions = None,
+        transformer_options={},
+        **kwargs,
+    ):
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        if self.control_adapter is not None and camera_conditions is not None:
+            x_camera = self.control_adapter(camera_conditions).to(x.dtype)
+            x = x + x_camera
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+        # context
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
 
+    if isinstance(model, comfy.model_base.ACEStep):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+
     return key_map
 
 
@@ -39,6 +39,7 @@ import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -923,6 +924,10 @@ class HunyuanVideo(BaseModel):
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
ref_latent = kwargs.get("ref_latent", None)
|
||||
if ref_latent is not None:
|
||||
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
@@ -1074,6 +1079,17 @@ class WAN21_Vace(WAN21):
        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
        return out

class WAN21_Camera(WAN21):
    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        camera_conditions = kwargs.get("camera_conditions", None)
        if camera_conditions is not None:
            out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
        return out

class Hunyuan3Dv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1111,7 +1127,7 @@ class HiDream(BaseModel):
        return out

class Chroma(Flux):
-   def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+   def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)

    def extra_conds(self, **kwargs):
@@ -1121,3 +1137,22 @@ class Chroma(Flux):
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

class ACEStep(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        noise = kwargs.get("noise", None)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
        if cross_attn is not None:
            out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
        out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
        out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
        return out
@@ -222,10 +222,39 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
        dit_config = {}
        dit_config["image_model"] = "ltxv"
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
        shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
        dit_config["attention_head_dim"] = shape[0] // 32
        dit_config["cross_attention_dim"] = shape[1]
        if metadata is not None and "config" in metadata:
            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
        return dit_config

    if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
        dit_config = {}
        dit_config["audio_model"] = "ace"
        dit_config["attention_head_dim"] = 128
        dit_config["in_channels"] = 8
        dit_config["inner_dim"] = 2560
        dit_config["max_height"] = 16
        dit_config["max_position"] = 32768
        dit_config["max_width"] = 32768
        dit_config["mlp_ratio"] = 2.5
        dit_config["num_attention_heads"] = 20
        dit_config["num_layers"] = 24
        dit_config["out_channels"] = 8
        dit_config["patch_size"] = [16, 1]
        dit_config["rope_theta"] = 1000000.0
        dit_config["speaker_embedding_dim"] = 512
        dit_config["text_embedding_dim"] = 768

        dit_config["ssl_encoder_depths"] = [8, 8]
        dit_config["ssl_latent_dims"] = [1024, 768]
        dit_config["ssl_names"] = ["mert", "m-hubert"]
        dit_config["lyric_encoder_vocab_size"] = 6693
        dit_config["lyric_hidden_size"] = 1024
        return dit_config

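A quick arithmetic check on the hardcoded ACE-Step config (an observation about the diff, not part of it):

# inner_dim is consistent with the attention layout:
assert 20 * 128 == 2560  # num_attention_heads * attention_head_dim == inner_dim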
    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
        patch_size = 2
        dit_config = {}
@@ -332,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["model_type"] = "vace"
            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "camera"
        else:
            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                dit_config["model_type"] = "i2v"
@@ -308,10 +308,10 @@ def fp8_linear(self, input):
        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            input = torch.clamp(input, min=-448, max=448, out=input)
-           input = input.reshape(-1, input_shape[2]).to(dtype)
+           input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
-           input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
+           input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

        if bias is not None:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
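Background on the `.contiguous()` additions: `torch._scaled_mm` is a private PyTorch API that, as far as I can tell, expects a row-major first operand (and a column-major second), so a non-contiguous view coming out of `reshape` can fail. A standalone sketch under that assumption, mirroring the call shape used in the hunk above (requires an fp8-capable GPU; treat the operand-layout requirement as an assumption, not documented behavior):

import torch

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn).contiguous()  # row-major input
w = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()           # column-major weight
one = torch.ones((), device="cuda", dtype=torch.float32)
o = torch._scaled_mm(a, w, out_dtype=torch.bfloat16, scale_a=one, scale_b=one)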
@@ -30,7 +30,7 @@ if RMSNorm is None:
    def __init__(
        self,
        normalized_shape,
-       eps=None,
+       eps=1e-6,
        elementwise_affine=True,
        device=None,
        dtype=None,
60
comfy/sd.py
@@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math

@@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace

import comfy.model_patcher
import comfy.lora
@@ -280,6 +282,7 @@ class VAE:

        self.downscale_index_formula = None
        self.upscale_index_formula = None
        self.extra_1d_channel = None

        if config is None:
            if "decoder.mid.block_1.mix_factor" in sd:
@@ -437,6 +440,20 @@ class VAE:
            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
            self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
            self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
            self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
            self.latent_channels = 8
            self.output_channels = 2
            self.upscale_ratio = 4096
            self.downscale_ratio = 4096
            self.latent_dim = 2
            self.process_output = lambda audio: audio
            self.process_input = lambda audio: audio
            self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
            self.disable_offload = True
            self.extra_1d_channel = 16
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
            self.first_stage_model = None
@@ -495,7 +512,13 @@ class VAE:
        return output

    def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
-       decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+       if samples.ndim == 3:
+           decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+       else:
+           og_shape = samples.shape
+           samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
+           decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()

        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))

    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@@ -515,9 +538,24 @@ class VAE:
        samples /= 3.0
        return samples

-   def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
-       encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-       return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+   def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
+       if self.latent_dim == 1:
+           encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+           out_channels = self.latent_channels
+           upscale_amount = 1 / self.downscale_ratio
+       else:
+           extra_channel_size = self.extra_1d_channel
+           out_channels = self.latent_channels * extra_channel_size
+           tile_x = tile_x // extra_channel_size
+           overlap = overlap // extra_channel_size
+           upscale_amount = 1 / self.downscale_ratio
+           encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+
+       out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
+       if self.latent_dim == 1:
+           return out
+       else:
+           return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)

    def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
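Worked shape sketch for the new 2-D audio path (the tensor sizes follow from the ACE-Step VAE settings above, latent_channels=8 and extra_1d_channel=16; the time length is an arbitrary example):

# The encode lambda folds both channel dims so the 1-D tiler only slides over time,
# seeing [B, 128, T] instead of [B, 8, 16, T]:
latent_channels, extra_channel_size = 8, 16
out_channels = latent_channels * extra_channel_size  # 128
# The final reshape then restores [B, 8, 16, T'] before the latent is returned.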
@@ -542,7 +580,7 @@ class VAE:
        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            dims = samples_in.ndim - 2
-           if dims == 1:
+           if dims == 1 or self.extra_1d_channel is not None:
                pixel_samples = self.decode_tiled_1d(samples_in)
            elif dims == 2:
                pixel_samples = self.decode_tiled_(samples_in)
@@ -609,7 +647,7 @@ class VAE:
                tile = 256
                overlap = tile // 4
                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-           elif self.latent_dim == 1:
+           elif self.latent_dim == 1 or self.extra_1d_channel is not None:
                samples = self.encode_tiled_1d(pixel_samples)
            else:
                samples = self.encode_tiled_(pixel_samples)
@@ -715,6 +753,7 @@ class CLIPType(Enum):
    WAN = 13
    HIDREAM = 14
    CHROMA = 15
    ACE = 16


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -840,8 +879,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
            clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
        elif te_model == TEModel.T5_BASE:
-           clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
-           clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+           if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
+               clip_target.clip = comfy.text_encoders.ace.AceT5Model
+               clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
+               tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+           else:
+               clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
+               clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
        elif te_model == TEModel.GEMMA_2_2B:
            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
@@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace

from . import supported_models_base
from . import latent_formats
@@ -785,6 +786,10 @@ class LTXV(supported_models_base.BASE):
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.LTXV(self, device=device)
        return out
@@ -987,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
        out = model_base.WAN21(self, image_to_video=False, device=device)
        return out

class WAN21_Camera(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "camera",
        "in_dim": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
        return out

class WAN21_Vace(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
@@ -1096,6 +1111,34 @@ class Chroma(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
+class ACEStep(supported_models_base.BASE):
+    unet_config = {
+        "audio_model": "ace",
+    }
+
+    unet_extra_config = {
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+    }
+
+    latent_format = comfy.latent_formats.ACEAudio
+
+    memory_usage_factor = 0.5
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.ACEStep(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]

models += [SVD_img2vid]
153
comfy/text_encoders/ace.py
Normal file
@@ -0,0 +1,153 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging

from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji

SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}

structure_pattern = re.compile(r"\[.*?\]")

DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")


class VoiceBpeTokenizer:
    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt, lang):
        txt = multilingual_cleaners(txt, lang)
        return txt

    def encode(self, txt, lang='en'):
        # lang = lang.split("-")[0]  # remove the region
        # self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def get_lang(self, line):
        if line.startswith("[") and line[3:4] == ']':
            lang = line[1:3].lower()
            if lang in SUPPORT_LANGUAGES:
                return lang, line[4:]
        return "en", line

    def __call__(self, string):
        lines = string.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang, line = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            try:
                line_out = japanese_to_romaji(line)
                if line_out != line:
                    lang = "ja"
                    line = line_out
            except:
                pass

            try:
                if structure_pattern.match(line):
                    token_idx = self.encode(line, "en")
                else:
                    token_idx = self.encode(line, lang)
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
        return {"input_ids": lyric_token_idx}

    @staticmethod
    def from_pretrained(path, **kwargs):
        return VoiceBpeTokenizer(path, **kwargs)

    def get_vocab(self):
        return {}

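A usage sketch of the lyrics tokenizer above (the specific token ids are not spelled out because they depend on the bundled vocab; only the 261 start token and the 2 line separator come from the code):

tok = VoiceBpeTokenizer()  # loads the bundled ace_lyrics_tokenizer/vocab.json
ids = tok("[es]Hola mundo\n[verse]\nhello")["input_ids"]
# -> [261, <es tokens>, 2, <"[verse]" tokens>, 2, <en tokens>, 2]
# Untagged lines default to English; "[verse]"-style structure tags match
# structure_pattern and are always encoded as English.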
class UMT5BaseModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)

class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}

class LyricsTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)

class AceT5Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
        out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.umt5base.untokenize(token_weight_pair)

    def state_dict(self):
        return self.umt5base.state_dict()

class AceT5Model(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__()
        self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        self.umt5base.set_clip_options(options)

    def reset_clip_options(self):
        self.umt5base.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

        t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)

        lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
        return t5_out, None, {"conditioning_lyrics": lyrics_embeds}

    def load_sd(self, sd):
        return self.umt5base.load_sd(sd)
15535
comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
395
comfy/text_encoders/ace_text_cleaners.py
Normal file
@@ -0,0 +1,395 @@
# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?

import re

def japanese_to_romaji(japanese_text):
    """
    Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).

    Args:
        japanese_text (str): Text containing hiragana and/or katakana characters

    Returns:
        str: The romaji (Latin alphabet) equivalent
    """
    # Dictionary mapping kana characters to their romaji equivalents
    kana_map = {
        # Katakana characters
        'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
        'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
        'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
        'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
        'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
        'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
        'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
        'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
        'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
        'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',

        # Katakana voiced consonants
        'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
        'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
        'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
        'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
        'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',

        # Katakana combinations
        'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
        'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
        'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
        'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
        'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
        'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
        'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
        'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
        'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
        'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
        'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',

        # Katakana small characters and special cases
        'ッ': '',  # Small tsu (doubles the following consonant)
        'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',

        # Katakana extras
        'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
        'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',

        # Hiragana characters
        'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
        'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
        'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
        'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
        'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
        'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
        'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
        'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
        'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
        'わ': 'wa', 'を': 'wo', 'ん': 'n',

        # Hiragana voiced consonants
        'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
        'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
        'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
        'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
        'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',

        # Hiragana combinations
        'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
        'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
        'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
        'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
        'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
        'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
        'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
        'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
        'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
        'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
        'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',

        # Hiragana small characters and special cases
        'っ': '',  # Small tsu (doubles the following consonant)
        'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',

        # Common punctuation and spaces
        '　': ' ',  # Japanese space (U+3000)
        '、': ', ', '。': '. ',
    }

    result = []
    i = 0

    while i < len(japanese_text):
        # Check for small tsu (doubling the following consonant)
        if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
            if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
                next_romaji = kana_map[japanese_text[i+1]]
                if next_romaji and next_romaji[0] not in 'aiueon':
                    result.append(next_romaji[0])  # Double the consonant
            i += 1
            continue

        # Check for combinations with small ya, yu, yo
        if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
            combo = japanese_text[i:i+2]
            if combo in kana_map:
                result.append(kana_map[combo])
                i += 2
                continue

        # Regular character
        if japanese_text[i] in kana_map:
            result.append(kana_map[japanese_text[i]])
        else:
            # If it's not in our map, keep it as is (might be kanji, romaji, etc.)
            result.append(japanese_text[i])

        i += 1

    return ''.join(result)

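A few example outputs of the converter above, following directly from the mapping table and loop logic:

print(japanese_to_romaji("こんにちは"))  # "konnichiha"
print(japanese_to_romaji("きょうと"))    # "kyouto" -- the きょ combo is matched before single kana
print(japanese_to_romaji("がっこう"))    # "gakkou" -- small っ doubles the following consonant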
def number_to_text(num, ordinal=False):
    """
    Convert a number (int or float) to its text representation.

    Args:
        num: The number to convert

    Returns:
        str: Text representation of the number
    """

    if not isinstance(num, (int, float)):
        return "Input must be a number"

    # Handle special case of zero
    if num == 0:
        return "zero"

    # Handle negative numbers
    negative = num < 0
    num = abs(num)

    # Handle floats
    if isinstance(num, float):
        # Split into integer and decimal parts
        int_part = int(num)

        # Convert both parts
        int_text = _int_to_text(int_part)

        # Handle decimal part (convert to string and remove '0.')
        decimal_str = str(num).split('.')[1]
        decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)

        result = int_text + decimal_text
    else:
        # Handle integers
        result = _int_to_text(num)

    # Add 'negative' prefix for negative numbers
    if negative:
        result = "negative " + result

    return result


def _int_to_text(num):
    """Helper function to convert an integer to text"""

    ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
            "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
            "seventeen", "eighteen", "nineteen"]

    tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

    if num < 20:
        return ones[num]

    if num < 100:
        return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")

    if num < 1000:
        return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")

    if num < 1000000:
        return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")

    if num < 1000000000:
        return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")

    return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")


def _digit_to_text(digit):
    """Convert a single digit to text"""
    digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    return digits[digit]
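A few traced outputs of the converter above (note the `ordinal` flag is accepted but currently unused, so ordinals come out as cardinals):

print(number_to_text(0))      # "zero"
print(number_to_text(-42))    # "negative forty two"
print(number_to_text(3.14))   # "three point one four"
print(number_to_text(1234))   # "one thousand two hundred thirty four"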


_whitespace_re = re.compile(r"\s+")


# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
    "en": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ],
}


def expand_abbreviations_multilingual(text, lang="en"):
    for regex, replacement in _abbreviations[lang]:
        text = re.sub(regex, replacement, text)
    return text


_symbols_multilingual = {
    "en": [
        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
        for x in [
            ("&", " and "),
            ("@", " at "),
            ("%", " percent "),
            ("#", " hash "),
            ("$", " dollar "),
            ("£", " pound "),
            ("°", " degree "),
        ]
    ],
}


def expand_symbols_multilingual(text, lang="en"):
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
    text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()


_ordinal_re = {
    "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}

_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")


def _remove_commas(m):
    text = m.group(0)
    if "," in text:
        text = text.replace(",", "")
    return text


def _remove_dots(m):
    text = m.group(0)
    if "." in text:
        text = text.replace(".", "")
    return text


def _expand_decimal_point(m, lang="en"):
    amount = m.group(1).replace(",", ".")
    return number_to_text(float(amount))


def _expand_currency(m, lang="en", currency="USD"):
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    full_amount = number_to_text(amount)

    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }

    if amount.is_integer():
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]

    return full_amount


def _expand_ordinal(m, lang="en"):
    return number_to_text(int(m.group(1)), ordinal=True)


def _expand_number(m, lang="en"):
    return number_to_text(int(m.group(0)))


def expand_numbers_multilingual(text, lang="en"):
    if lang in ["en", "ru"]:
        text = re.sub(_comma_number_re, _remove_commas, text)
    else:
        text = re.sub(_dot_number_re, _remove_dots, text)
    try:
        text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
        text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
        text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
    except:
        pass

    text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
    text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
    text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def multilingual_cleaners(text, lang):
    text = text.replace('"', "")
    if lang == "tr":
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    try:
        text = expand_numbers_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except:
        pass
    text = collapse_whitespace(text)
    return text


def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
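End-to-end behavior of the cleaner, traced through the rules above (lowercase, then numbers, then abbreviations, then symbols):

print(multilingual_cleaners('Mr. Smith has 12 cats & 3.5 dogs', "en"))
# -> "mister smith has twelve cats and three point five dogs"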
22
comfy/text_encoders/umt5_config_base.json
Normal file
@@ -0,0 +1,22 @@
{
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "dense_act_fn": "gelu_pytorch_tanh",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "umt5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "vocab_size": 256384
}
@@ -28,6 +28,9 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args

MMAP_TORCH_FILES = args.mmap_torch_files

ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated

@@ -67,12 +70,14 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
            raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
        raise e
    else:
        torch_args = {}
        if MMAP_TORCH_FILES:
            torch_args["mmap"] = True

        if safe_load or ALWAYS_SAFE_LOAD:
-           pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
+           pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
        else:
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
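For context, the `mmap` keyword (available for `torch.load` since roughly PyTorch 2.1, for zipfile-format checkpoints) memory-maps the file instead of reading it fully into RAM, so pages are loaded lazily as tensors are touched. A standalone sketch with a hypothetical path:

import torch

sd = torch.load("model.ckpt", map_location="cpu", weights_only=True, mmap=True)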
@@ -43,3 +43,13 @@ class VideoInput(ABC):
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        components = self.get_components()
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)
@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
                return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        with av.open(self.__file, mode="r") as container:
            if container.duration is not None:
                return float(container.duration / av.time_base)

            # Fallback: calculate from frame count and frame rate
            video_stream = next(
                (s for s in container.streams if s.type == "video"), None
            )
            if video_stream and video_stream.frames and video_stream.average_rate:
                return float(video_stream.frames / video_stream.average_rate)

            # Last resort: decode frames to count them
            if video_stream and video_stream.average_rate:
                frame_count = 0
                container.seek(0)
                for packet in container.demux(video_stream):
                    for _ in packet.decode():
                        frame_count += 1
                if frame_count > 0:
                    return float(frame_count / video_stream.average_rate)

            raise ValueError(f"Could not determine duration for file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        # Get video frames
        frames = []
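Usage sketch (the file path is hypothetical, and the constructor argument follows the class shown above rather than documented API):

from comfy_api.input_impl import VideoFromFile

video = VideoFromFile("input/clip.mp4")
print(f"duration: {video.get_duration():.2f}s")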
@@ -1,6 +1,7 @@
from __future__ import annotations
import io
import logging
-from typing import Optional
+from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
@@ -14,6 +15,7 @@ from comfy_api_nodes.apis.client import (
    UploadRequest,
    UploadResponse,
)
from server import PromptServer


import numpy as np
@@ -59,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
    return s


-def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
+def validate_and_cast_response(
+    response, timeout: int = None, node_id: Union[str, None] = None
+) -> torch.Tensor:
    """Validates and casts a response to a torch.Tensor.

    Args:
@@ -93,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
        img = Image.open(io.BytesIO(img_data))

    elif image_url:
        if node_id:
            PromptServer.instance.send_progress_text(
                f"Result URL: {image_url}", node_id
            )
        img_response = requests.get(image_url, timeout=timeout)
        if img_response.status_code != 200:
            raise ValueError("Failed to download the image")
@@ -314,7 +322,7 @@ def upload_file_to_comfyapi(
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: str,
-   auth_token: Optional[str] = None,
+   auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
    """
    Uploads a single file to ComfyUI API and returns its download URL.
@@ -323,7 +331,7 @@ def upload_file_to_comfyapi(
        file_bytes_io: BytesIO object containing the file data.
        filename: The filename of the file.
        upload_mime_type: MIME type of the file.
-       auth_token: Optional authentication token.
+       auth_kwargs: Optional authentication token(s).

    Returns:
        The download URL for the uploaded file.
@@ -337,7 +345,7 @@ def upload_file_to_comfyapi(
            response_model=UploadResponse,
        ),
        request=request_object,
-       auth_token=auth_token,
+       auth_kwargs=auth_kwargs,
    )

    response: UploadResponse = operation.execute()
@@ -351,7 +359,7 @@ def upload_file_to_comfyapi(

def upload_video_to_comfyapi(
    video: VideoInput,
-   auth_token: Optional[str] = None,
+   auth_kwargs: Optional[dict[str,str]] = None,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
@@ -362,7 +370,7 @@ def upload_video_to_comfyapi(

    Args:
        video: VideoInput object (Comfy VIDEO type).
-       auth_token: Optional authentication token.
+       auth_kwargs: Optional authentication token(s).
        container: The video container format to use (default: MP4).
        codec: The video codec to use (default: H264).
        max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
@@ -390,7 +398,7 @@ def upload_video_to_comfyapi(
    video_bytes_io.seek(0)

    return upload_file_to_comfyapi(
-       video_bytes_io, filename, upload_mime_type, auth_token
+       video_bytes_io, filename, upload_mime_type, auth_kwargs
    )


@@ -453,7 +461,7 @@ def audio_ndarray_to_bytesio(

def upload_audio_to_comfyapi(
    audio: AudioInput,
-   auth_token: Optional[str] = None,
+   auth_kwargs: Optional[dict[str,str]] = None,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
@@ -465,7 +473,7 @@ def upload_audio_to_comfyapi(

    Args:
        audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
-       auth_token: Optional authentication token.
+       auth_kwargs: Optional authentication token(s).

    Returns:
        The download URL for the uploaded audio file.
@@ -477,11 +485,11 @@ def upload_audio_to_comfyapi(
        audio_data_np, sample_rate, container_format, codec_name
    )

-   return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_token)
+   return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


def upload_images_to_comfyapi(
-   image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None
+   image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
@@ -490,7 +498,7 @@ def upload_images_to_comfyapi(
    Args:
        image: Input torch.Tensor image.
        max_images: Maximum number of images to upload.
-       auth_token: Optional authentication token.
+       auth_kwargs: Optional authentication token(s).
        mime_type: Optional MIME type for the image.
    """
    # if batch, try to upload each file if max_images is greater than 0
@@ -521,7 +529,7 @@ def upload_images_to_comfyapi(
            response_model=UploadResponse,
        ),
        request=request_object,
-       auth_token=auth_token,
+       auth_kwargs=auth_kwargs,
    )
    response = operation.execute()
@@ -20,7 +20,8 @@ Usage Examples:
    # 1. Create the API client
    api_client = ApiClient(
        base_url="https://api.example.com",
-       api_key="your_api_key_here",
+       auth_token="your_auth_token_here",
+       comfy_api_key="your_comfy_api_key_here",
        timeout=30.0,
        verify_ssl=True
    )
@@ -93,15 +94,19 @@ from __future__ import annotations
import logging
import time
import io
-from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
+import socket
+from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
import requests
-from urllib.parse import urljoin
+from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid  # For generating unique operation IDs

from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger

T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
@@ -110,6 +115,21 @@ P = TypeVar("P", bound=BaseModel)  # For poll response

PROGRESS_BAR_MAX = 100


class NetworkError(Exception):
    """Base exception for network-related errors with diagnostic information."""
    pass


class LocalNetworkError(NetworkError):
    """Exception raised when local network connectivity issues are detected."""
    pass


class ApiServerError(NetworkError):
    """Exception raised when the API server is unreachable but internet is working."""
    pass


class EmptyRequest(BaseModel):
    """Base class for empty request bodies.
    For GET requests, fields will be sent as query parameters."""
@@ -140,20 +160,36 @@ class HttpMethod(str, Enum):

class ApiClient:
    """
-   Client for making HTTP requests to an API with authentication and error handling.
+   Client for making HTTP requests to an API with authentication, error handling, and retry logic.
    """

    def __init__(
        self,
        base_url: str,
-       api_key: Optional[str] = None,
+       auth_token: Optional[str] = None,
+       comfy_api_key: Optional[str] = None,
        timeout: float = 3600.0,
        verify_ssl: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        retry_status_codes: Optional[Tuple[int, ...]] = None,
    ):
        self.base_url = base_url
-       self.api_key = api_key
+       self.auth_token = auth_token
+       self.comfy_api_key = comfy_api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
        # 500, 502, 503, 504 (Server Errors)
        self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)

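        # [Annotation, not part of the commit] With these defaults the wait before
        # retry k is retry_delay * retry_backoff_factor ** k, i.e. 1.0s, 2.0s, 4.0s
        # for k = 0, 1, 2, after which the request fails for good.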
    def _generate_operation_id(self, path: str) -> str:
        """Generates a unique operation ID for logging."""
        return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"

    def _create_json_payload_args(
        self,
@@ -201,11 +237,63 @@ class ApiClient:
        """Get headers for API requests, including authentication if available"""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

-       if self.api_key:
-           headers["Authorization"] = f"Bearer {self.api_key}"
+       if self.auth_token:
+           headers["Authorization"] = f"Bearer {self.auth_token}"
+       elif self.comfy_api_key:
+           headers["X-API-KEY"] = self.comfy_api_key

        return headers

    def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
        """
        Check connectivity to determine if network issues are local or server-related.

        Args:
            target_url: URL to check connectivity to

        Returns:
            Dictionary with connectivity status details
        """
        results = {
            "internet_accessible": False,
            "api_accessible": False,
            "is_local_issue": False,
            "is_api_issue": False
        }

        # First check basic internet connectivity using a reliable external site
        try:
            # Use a reliable external domain for checking basic connectivity
            check_response = requests.get("https://www.google.com",
                                          timeout=5.0,
                                          verify=self.verify_ssl)
            if check_response.status_code < 500:
                results["internet_accessible"] = True
        except (requests.RequestException, socket.error):
            results["internet_accessible"] = False
            results["is_local_issue"] = True
            return results

        # Now check API server connectivity
        try:
            # Extract domain from the target URL to do a simpler health check
            parsed_url = urlparse(target_url)
            api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"

            # Try to reach the API domain
            api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
            if api_response.status_code < 500:
                results["api_accessible"] = True
            else:
                results["api_accessible"] = False
                results["is_api_issue"] = True
        except requests.RequestException:
            results["api_accessible"] = False
            # If we can reach the internet but not the API, it's an API issue
            results["is_api_issue"] = True

        return results
    def request(
        self,
        method: str,
@@ -216,9 +304,10 @@ class ApiClient:
        headers: Optional[Dict[str, str]] = None,
        content_type: str = "application/json",
        multipart_parser: Callable = None,
        retry_count: int = 0,  # Used internally for tracking retries
    ) -> Dict[str, Any]:
        """
-       Make an HTTP request to the API
+       Make an HTTP request to the API with automatic retries for transient errors.

        Args:
            method: HTTP method (GET, POST, etc.)
@@ -228,15 +317,18 @@ class ApiClient:
            files: Files to upload
            headers: Additional headers
            content_type: Content type of the request. Defaults to application/json.
            retry_count: Internal parameter for tracking retries, do not set manually

        Returns:
            Parsed JSON response

        Raises:
-           requests.RequestException: If the request fails
+           LocalNetworkError: If local network connectivity issues are detected
+           ApiServerError: If the API server is unreachable but internet is working
+           Exception: For other request failures
        """
        url = urljoin(self.base_url, path)
-       self.check_auth_token(self.api_key)
+       self.check_auth(self.auth_token, self.comfy_api_key)
        # Combine default headers with any provided headers
        request_headers = self.get_headers()
        if headers:
@@ -260,6 +352,16 @@ class ApiClient:
        else:
            payload_args = self._create_json_payload_args(data, request_headers)

        operation_id = self._generate_operation_id(path)
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=method,
            request_url=url,
            request_headers=request_headers,
            request_params=params,
            request_data=data if content_type == "application/json" else "[form-data or other]"
        )

        try:
            response = requests.request(
                method=method,
@@ -270,87 +372,365 @@ class ApiClient:
                **payload_args,
            )

            # Check if we should retry based on status code
            if (response.status_code in self.retry_status_codes and
                    retry_count < self.max_retries):

                # Calculate delay with exponential backoff
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)

                logging.warning(
                    f"Request failed with status {response.status_code}. "
                    f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
                )

                time.sleep(delay)
                return self.request(
                    method=method,
                    path=path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )

            # Raise exception for error status codes
            response.raise_for_status()
-       except requests.ConnectionError:
-           raise Exception(
-               f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
-           )

            # Log successful response
            response_content_to_log = response.content
            try:
                # Attempt to parse JSON for prettier logging, fallback to raw content
                response_content_to_log = response.json()
            except json.JSONDecodeError:
                pass  # Keep as bytes/str if not JSON

            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,  # Pass request details again for context in log
                request_url=url,
                response_status_code=response.status_code,
                response_headers=dict(response.headers),
                response_content=response_content_to_log
            )

-       except requests.Timeout:
-           raise Exception(
-               f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
-           )
        except requests.ConnectionError as e:
            error_message = f"ConnectionError: {str(e)}"
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                error_message=error_message
            )
            # Only perform connectivity check if we've exhausted all retries
            if retry_count >= self.max_retries:
                # Check connectivity to determine if it's a local or API issue
                connectivity = self._check_connectivity(self.base_url)

                if connectivity["is_local_issue"]:
                    raise LocalNetworkError(
                        "Unable to connect to the API server due to local network issues. "
                        "Please check your internet connection and try again."
                    ) from e
                elif connectivity["is_api_issue"]:
                    raise ApiServerError(
                        f"The API server at {self.base_url} is currently unreachable. "
                        f"The service may be experiencing issues. Please try again later."
                    ) from e

            # If we haven't exhausted retries yet, retry the request
            if retry_count < self.max_retries:
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
                logging.warning(
                    f"Connection error: {str(e)}. "
                    f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
                )
                time.sleep(delay)
                return self.request(
                    method=method,
                    path=path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )

            # If we've exhausted retries and didn't identify the specific issue,
            # raise a generic exception
            final_error_message = (
                f"Unable to connect to the API server after {self.max_retries} attempts. "
                f"Please check your internet connection or try again later."
            )
            request_logger.log_request_response(  # Log final failure
                operation_id=operation_id,
                request_method=method, request_url=url,
                error_message=final_error_message
            )
            raise Exception(final_error_message) from e

        except requests.Timeout as e:
            error_message = f"Timeout: {str(e)}"
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method, request_url=url,
                error_message=error_message
            )
            # Retry timeouts if we haven't exhausted retries
            if retry_count < self.max_retries:
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
                logging.warning(
                    f"Request timed out. "
                    f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
                )
                time.sleep(delay)
                return self.request(
                    method=method,
                    path=path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )
            final_error_message = (
                f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
                f"The server might be experiencing high load or the operation is taking longer than expected."
            )
            request_logger.log_request_response(  # Log final failure
                operation_id=operation_id,
                request_method=method, request_url=url,
                error_message=final_error_message
            )
            raise Exception(final_error_message) from e

        except requests.HTTPError as e:
            status_code = e.response.status_code if hasattr(e, "response") else None
-           error_message = f"HTTP Error: {str(e)}"
+           original_error_message = f"HTTP Error: {str(e)}"
            error_content_for_log = None
            if hasattr(e, "response") and e.response is not None:
                error_content_for_log = e.response.content
                try:
                    error_content_for_log = e.response.json()
                except json.JSONDecodeError:
                    pass
||||
|
||||
|
||||
# Try to extract detailed error message from JSON response for user display
|
||||
# but log the full error content.
|
||||
user_display_error_message = original_error_message
|
||||
|
||||
# Try to extract detailed error message from JSON response
|
||||
try:
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
error_json = e.response.json()
|
||||
if "error" in error_json and "message" in error_json["error"]:
|
||||
error_message = f"API Error: {error_json['error']['message']}"
|
||||
user_display_error_message = f"API Error: {error_json['error']['message']}"
|
||||
if "type" in error_json["error"]:
|
||||
error_message += f" (Type: {error_json['error']['type']})"
|
||||
user_display_error_message += f" (Type: {error_json['error']['type']})"
|
||||
elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
|
||||
user_display_error_message = f"API Error: {json.dumps(error_json)}"
|
||||
else: # Non-dict JSON error
|
||||
user_display_error_message = f"API Error: {str(error_json)}"
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, use the raw content if it's not too long, or a summary
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
raw_content = e.response.content.decode(errors='ignore')
|
||||
if len(raw_content) < 200: # Arbitrary limit for display
|
||||
user_display_error_message = f"API Error (raw): {raw_content}"
|
||||
else:
|
||||
error_message = f"API Error: {error_json}"
|
||||
except Exception as json_error:
|
||||
# If we can't parse the JSON, fall back to the original error message
|
||||
logging.debug(
|
||||
f"[DEBUG] Failed to parse error response: {str(json_error)}"
|
||||
user_display_error_message = f"API Error (raw, status {status_code})"
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method, request_url=url,
|
||||
response_status_code=status_code,
|
||||
response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
|
||||
response_content=error_content_for_log,
|
||||
error_message=original_error_message # Log the original exception string as error
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||
|
||||
# Retry if the status code is in our retry list and we haven't exhausted retries
|
||||
if (status_code in self.retry_status_codes and
|
||||
retry_count < self.max_retries):
|
||||
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
f"HTTP error {status_code}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
return self.request(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||
# Specific error messages for common status codes for user display
|
||||
if status_code == 401:
|
||||
error_message = "Unauthorized: Please login first to use this node."
|
||||
if status_code == 402:
|
||||
error_message = "Payment Required: Please add credits to your account to use this node."
|
||||
if status_code == 409:
|
||||
error_message = "There is a problem with your account. Please contact support@comfy.org. "
|
||||
if status_code == 429:
|
||||
error_message = "Rate Limit Exceeded: Please try again later."
|
||||
raise Exception(error_message)
|
||||
user_display_error_message = "Unauthorized: Please login first to use this node."
|
||||
elif status_code == 402:
|
||||
user_display_error_message = "Payment Required: Please add credits to your account to use this node."
|
||||
elif status_code == 409:
|
||||
user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
|
||||
elif status_code == 429:
|
||||
user_display_error_message = "Rate Limit Exceeded: Please try again later."
|
||||
# else, user_display_error_message remains as parsed from response or original HTTPError string
|
||||
|
||||
raise Exception(user_display_error_message) # Raise with the user-friendly message
|
||||
|
||||
# Parse and return JSON response
|
||||
if response.content:
|
||||
return response.json()
|
||||
return {}
|
||||
|
||||
def check_auth_token(self, auth_token):
|
||||
"""Verify that an auth token is present."""
|
||||
if auth_token is None:
|
||||
def check_auth(self, auth_token, comfy_api_key):
|
||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||||
if auth_token is None and comfy_api_key is None:
|
||||
raise Exception("Unauthorized: Please login first to use this node.")
|
||||
return auth_token
|
||||
return auth_token or comfy_api_key
|
||||
|
||||
@staticmethod
|
||||
def upload_file(
|
||||
upload_url: str,
|
||||
file: io.BytesIO | str,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
):
|
||||
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
|
||||
"""Upload a file to the API with retry logic.
|
||||
|
||||
Args:
|
||||
upload_url: The URL to upload to
|
||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||
mime_type: Optional mime type to set for the upload
|
||||
content_type: Optional mime type to set for the upload
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
retry_backoff_factor: Multiplier for the delay after each retry
|
||||
"""
|
||||
headers = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
# Prepare the file data
|
||||
if isinstance(file, io.BytesIO):
|
||||
file.seek(0) # Ensure we're at the start of the file
|
||||
data = file.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
else:
|
||||
raise ValueError("File must be either a BytesIO object or a file path string")
|
||||
|
||||
# Try the upload with retries
|
||||
last_exception = None
|
||||
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
|
||||
|
||||
# Log initial attempt (without full file data for brevity)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers,
|
||||
request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
|
||||
)
|
||||
|
||||
for retry_attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.put(upload_url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url, # For context
|
||||
response_status_code=response.status_code,
|
||||
response_headers=dict(response.headers),
|
||||
response_content="File uploaded successfully." # Or response.text if available
|
||||
)
|
||||
return response
|
||||
|
||||
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
|
||||
last_exception = e
|
||||
error_message_for_log = f"{type(e).__name__}: {str(e)}"
|
||||
response_content_for_log = None
|
||||
status_code_for_log = None
|
||||
headers_for_log = None
|
||||
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
status_code_for_log = e.response.status_code
|
||||
headers_for_log = dict(e.response.headers)
|
||||
try:
|
||||
response_content_for_log = e.response.json()
|
||||
except json.JSONDecodeError:
|
||||
response_content_for_log = e.response.content
|
||||
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
response_status_code=status_code_for_log,
|
||||
response_headers=headers_for_log,
|
||||
response_content=response_content_for_log,
|
||||
error_message=error_message_for_log
|
||||
)
|
||||
|
||||
if retry_attempt < max_retries:
|
||||
delay = retry_delay * (retry_backoff_factor ** retry_attempt)
|
||||
logging.warning(
|
||||
f"File upload failed: {str(e)}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break # Max retries reached
|
||||
|
||||
# If we've exhausted all retries, determine the final error type and raise
|
||||
final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
|
||||
try:
|
||||
# Check basic internet connectivity
|
||||
check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
|
||||
if check_response.status_code >= 500: # Google itself has an issue (rare)
|
||||
final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
|
||||
f"(status {check_response.status_code}). Original error: {str(last_exception)}")
|
||||
# Not raising LocalNetworkError here as Google itself might be down.
|
||||
# If Google is reachable, the issue is likely with the upload server or a more specific local problem
|
||||
# not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
|
||||
# The original last_exception is probably most relevant.
|
||||
|
||||
except (requests.RequestException, socket.error) as conn_check_exc:
|
||||
# Could not reach Google, likely a local network issue
|
||||
final_error_message = (f"Failed to upload file due to network connectivity issues "
|
||||
f"(cannot reach Google: {str(conn_check_exc)}). "
|
||||
f"Original upload error: {str(last_exception)}")
|
||||
request_logger.log_request_response( # Log final failure reason
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise LocalNetworkError(final_error_message) from last_exception
|
||||
|
||||
request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise Exception(final_error_message) from last_exception
|
||||
|
||||
|
||||
class ApiEndpoint(Generic[T, R]):
|
||||
@@ -392,10 +772,15 @@ class SynchronousOperation(Generic[T, R]):
|
||||
files: Optional[Dict[str, Any]] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
timeout: float = 604800.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
@@ -403,21 +788,33 @@ class SynchronousOperation(Generic[T, R]):
|
||||
self.error = None
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.files = files
|
||||
self.content_type = content_type
|
||||
self.multipart_parser = multipart_parser
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
|
||||
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
"""Execute the API operation using the provided client or create one"""
|
||||
"""Execute the API operation using the provided client or create one with retry support"""
|
||||
try:
|
||||
# Create client if not provided
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
api_key=self.auth_token,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
timeout=self.timeout,
|
||||
verify_ssl=self.verify_ssl,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
|
||||
# Convert request model to dict, but use None for EmptyRequest
|
||||
@@ -431,11 +828,6 @@ class SynchronousOperation(Generic[T, R]):
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
if request_dict:
|
||||
for key, value in request_dict.items():
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
# Debug log for request
|
||||
logging.debug(
|
||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||
@@ -443,7 +835,7 @@ class SynchronousOperation(Generic[T, R]):
|
||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||
|
||||
# Make the request
|
||||
# Make the request with built-in retry
|
||||
resp = client.request(
|
||||
method=self.endpoint.method.value,
|
||||
path=self.endpoint.path,
|
||||
@@ -464,8 +856,18 @@ class SynchronousOperation(Generic[T, R]):
|
||||
# Parse and return the response
|
||||
return self._parse_response(resp)
|
||||
|
||||
except LocalNetworkError as e:
|
||||
# Propagate specific network error types
|
||||
logging.error(f"[ERROR] Local network error: {str(e)}")
|
||||
raise
|
||||
|
||||
except ApiServerError as e:
|
||||
# Propagate API server errors
|
||||
logging.error(f"[ERROR] API server error: {str(e)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"[DEBUG] API Exception: {str(e)}")
|
||||
logging.error(f"[ERROR] API Exception: {str(e)}")
|
||||
raise Exception(str(e))
|
||||
|
||||
def _parse_response(self, resp):
|
||||
@@ -499,22 +901,42 @@ class PollingOperation(Generic[T, R]):
|
||||
failed_statuses: list,
|
||||
status_extractor: Callable[[R], str],
|
||||
progress_extractor: Callable[[R], float] = None,
|
||||
result_url_extractor: Callable[[R], str] = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||
max_retries: int = 3, # Max retries per individual API call
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
estimated_duration: Optional[float] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.poll_interval = poll_interval
|
||||
self.max_poll_attempts = max_poll_attempts
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
self.estimated_duration = estimated_duration
|
||||
|
||||
# Polling configuration
|
||||
self.status_extractor = status_extractor or (
|
||||
lambda x: getattr(x, "status", None)
|
||||
)
|
||||
self.progress_extractor = progress_extractor
|
||||
self.result_url_extractor = result_url_extractor
|
||||
self.node_id = node_id
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
|
||||
@@ -528,12 +950,48 @@ class PollingOperation(Generic[T, R]):
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
api_key=self.auth_token,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
return self._poll_until_complete(client)
|
||||
except LocalNetworkError as e:
|
||||
# Provide clear message for local network issues
|
||||
raise Exception(
|
||||
f"Polling failed due to local network issues. Please check your internet connection. "
|
||||
f"Details: {str(e)}"
|
||||
) from e
|
||||
except ApiServerError as e:
|
||||
# Provide clear message for API server issues
|
||||
raise Exception(
|
||||
f"Polling failed due to API server issues. The service may be experiencing problems. "
|
||||
f"Please try again later. Details: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise Exception(f"Error during polling: {str(e)}")
|
||||
|
||||
def _display_text_on_node(self, text: str):
|
||||
"""Sends text to the client which will be displayed on the node in the UI"""
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||||
|
||||
def _display_time_progress_on_node(self, time_completed: int):
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
if self.estimated_duration is not None:
|
||||
estimated_time_remaining = max(
|
||||
0, int(self.estimated_duration) - int(time_completed)
|
||||
)
|
||||
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
|
||||
else:
|
||||
message = f"Task in progress: {time_completed:.0f}s"
|
||||
self._display_text_on_node(message)
|
||||
|
||||
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
"""Check task status using the status extractor function"""
|
||||
try:
|
||||
@@ -550,10 +1008,13 @@ class PollingOperation(Generic[T, R]):
|
||||
def _poll_until_complete(self, client: ApiClient) -> R:
|
||||
"""Poll until the task is complete"""
|
||||
poll_count = 0
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
||||
|
||||
if self.progress_extractor:
|
||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||
|
||||
while True:
|
||||
while poll_count < self.max_poll_attempts:
|
||||
try:
|
||||
poll_count += 1
|
||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||
@@ -580,8 +1041,12 @@ class PollingOperation(Generic[T, R]):
|
||||
data=request_dict,
|
||||
)
|
||||
|
||||
# Successfully got a response, reset consecutive error count
|
||||
consecutive_errors = 0
|
||||
|
||||
# Parse response
|
||||
response_obj = self.poll_endpoint.response_model.model_validate(resp)
|
||||
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||
@@ -593,7 +1058,15 @@ class PollingOperation(Generic[T, R]):
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
logging.debug("[DEBUG] Task completed successfully")
|
||||
message = "Task completed successfully"
|
||||
if self.result_url_extractor:
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
else:
|
||||
message = "Task completed successfully!"
|
||||
logging.debug(f"[DEBUG] {message}")
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
@@ -609,8 +1082,43 @@ class PollingOperation(Generic[T, R]):
|
||||
logging.debug(
|
||||
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
|
||||
)
|
||||
for i in range(int(self.poll_interval)):
|
||||
time_completed = (poll_count * self.poll_interval) + i
|
||||
self._display_time_progress_on_node(time_completed)
|
||||
time.sleep(1)
|
||||
|
||||
except (LocalNetworkError, ApiServerError) as e:
|
||||
# For network-related errors, increment error count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Log the error but continue polling
|
||||
logging.warning(
|
||||
f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
except Exception as e:
|
||||
# For other errors, increment count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||
raise Exception(f"Error while polling: {str(e)}")
|
||||
logging.warning(
|
||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
# If we've exhausted all polling attempts
|
||||
raise Exception(
|
||||
f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
|
||||
f"The operation may still be running on the server but is taking longer than expected."
|
||||
)
|
||||
|
||||
@@ -81,7 +81,6 @@ class RecraftStyle:
|
||||
|
||||
class RecraftIO:
|
||||
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||
SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
|
||||
COLOR = "RECRAFT_COLOR"
|
||||
CONTROLS = "RECRAFT_CONTROLS"
|
||||
|
||||
|
||||
125
comfy_api_nodes/apis/request_logger.py
Normal file
125
comfy_api_nodes/apis/request_logger.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_log_directory():
|
||||
"""
|
||||
Ensures the API log directory exists within ComfyUI's temp directory
|
||||
and returns its path.
|
||||
"""
|
||||
base_temp_dir = folder_paths.get_temp_directory()
|
||||
log_dir = os.path.join(base_temp_dir, "api_logs")
|
||||
try:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating API log directory {log_dir}: {e}")
|
||||
# Fallback to base temp directory if sub-directory creation fails
|
||||
return base_temp_dir
|
||||
return log_dir
|
||||
|
||||
def _format_data_for_logging(data):
|
||||
"""Helper to format data (dict, str, bytes) for logging."""
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
return data.decode('utf-8') # Try to decode as text
|
||||
except UnicodeDecodeError:
|
||||
return f"[Binary data of length {len(data)} bytes]"
|
||||
elif isinstance(data, (dict, list)):
|
||||
try:
|
||||
return json.dumps(data, indent=2, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(data) # Fallback for non-serializable objects
|
||||
return str(data)
|
||||
|
||||
def log_request_response(
|
||||
operation_id: str,
|
||||
request_method: str,
|
||||
request_url: str,
|
||||
request_headers: dict | None = None,
|
||||
request_params: dict | None = None,
|
||||
request_data: any = None,
|
||||
response_status_code: int | None = None,
|
||||
response_headers: dict | None = None,
|
||||
response_content: any = None,
|
||||
error_message: str | None = None
|
||||
):
|
||||
"""
|
||||
Logs API request and response details to a file in the temp/api_logs directory.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
|
||||
filepath = os.path.join(log_dir, filename)
|
||||
|
||||
log_content = []
|
||||
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug(f"API log saved to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing API log to {filepath}: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Example usage (for testing the logger directly)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# Mock folder_paths for direct execution if not running within ComfyUI full context
|
||||
if not hasattr(folder_paths, 'get_temp_directory'):
|
||||
class MockFolderPaths:
|
||||
def get_temp_directory(self):
|
||||
# Create a local temp dir for testing if needed
|
||||
p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
|
||||
os.makedirs(p, exist_ok=True)
|
||||
return p
|
||||
folder_paths = MockFolderPaths()
|
||||
|
||||
log_request_response(
|
||||
operation_id="test_operation_get",
|
||||
request_method="GET",
|
||||
request_url="https://api.example.com/test",
|
||||
request_headers={"Authorization": "Bearer testtoken"},
|
||||
request_params={"param1": "value1"},
|
||||
response_status_code=200,
|
||||
response_content={"message": "Success!"}
|
||||
)
|
||||
log_request_response(
|
||||
operation_id="test_operation_post_error",
|
||||
request_method="POST",
|
||||
request_url="https://api.example.com/submit",
|
||||
request_data={"key": "value", "nested": {"num": 123}},
|
||||
error_message="Connection timed out"
|
||||
)
|
||||
log_request_response(
|
||||
operation_id="test_binary_response",
|
||||
request_method="GET",
|
||||
request_url="https://api.example.com/image.png",
|
||||
response_status_code=200,
|
||||
response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
|
||||
)
|
||||
10
comfy_api_nodes/canary.py
Normal file
10
comfy_api_nodes/canary.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import av
|
||||
|
||||
ver = av.__version__.split(".")
|
||||
if int(ver[0]) < 14:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
if int(ver[0]) == 14 and int(ver[1]) < 2:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
@@ -1,5 +1,6 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from typing import Union
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLStatus,
|
||||
@@ -30,6 +31,7 @@ import requests
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
@@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor):
|
||||
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation, timeout_bfl_calls=360
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
def _poll_until_generated(polling_url: str, timeout=360):
|
||||
|
||||
def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||
start_time = time.time()
|
||||
@@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
@@ -179,6 +196,8 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -211,7 +230,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
image_prompt_strength=0.1,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if image_prompt is None:
|
||||
@@ -244,9 +263,9 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||
),
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -319,6 +338,8 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -337,7 +358,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image_prompt = (
|
||||
@@ -361,9 +382,9 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -457,10 +478,11 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -482,7 +504,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image = convert_image_to_base64(image)
|
||||
@@ -506,9 +528,9 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
seed=seed,
|
||||
image=image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -568,10 +590,11 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -590,14 +613,14 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:,:,:,:3])
|
||||
image = convert_image_to_base64(image[:, :, :, :3])
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -615,9 +638,9 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
image=image,
|
||||
mask=mask,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -702,10 +725,11 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -726,10 +750,10 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
control_image = convert_image_to_base64(control_image[:, :, :, :3])
|
||||
preprocessed_image = None
|
||||
|
||||
# scale canny threshold between 0-500, to match BFL's API
|
||||
@@ -763,9 +787,9 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -830,10 +854,11 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -852,7 +877,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
@@ -878,9 +903,9 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
bytesio_to_image_tensor,
|
||||
resize_mask_to_image,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
@@ -232,11 +233,22 @@ def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V1 model.
|
||||
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -303,7 +315,11 @@ class IdeogramV1(ComfyNodeABC):
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -321,7 +337,8 @@ class IdeogramV1(ComfyNodeABC):
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
@@ -347,7 +364,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
@@ -360,14 +377,13 @@ class IdeogramV1(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V2 model.
|
||||
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -458,7 +474,11 @@ class IdeogramV2(ComfyNodeABC):
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -479,7 +499,8 @@ class IdeogramV2(ComfyNodeABC):
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
@@ -519,7 +540,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
@@ -532,14 +553,12 @@ class IdeogramV2(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V3 model.
|
||||
|
||||
Supports both regular image generation from text prompts and image editing with mask.
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -621,7 +640,11 @@ class IdeogramV3(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -641,7 +664,8 @@ class IdeogramV3(ComfyNodeABC):
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
@@ -705,7 +729,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
"mask": mask_binary,
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
@@ -746,7 +770,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
@@ -760,6 +784,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
@@ -774,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar, Any
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
import logging
|
||||
|
||||
@@ -64,6 +65,12 @@ from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
@@ -79,13 +86,20 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
|
||||
PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on"
|
||||
PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH_T2V = 2500
|
||||
MAX_PROMPT_LENGTH_I2V = 500
|
||||
MAX_PROMPT_LENGTH_IMAGE_GEN = 500
|
||||
MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
|
||||
MAX_PROMPT_LENGTH_LIP_SYNC = 120
|
||||
|
||||
AVERAGE_DURATION_T2V = 319
|
||||
AVERAGE_DURATION_I2V = 164
|
||||
AVERAGE_DURATION_LIP_SYNC = 455
|
||||
AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
|
||||
AVERAGE_DURATION_IMAGE_GEN = 32
|
||||
AVERAGE_DURATION_VIDEO_EFFECTS = 320
|
||||
AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@@ -95,7 +109,13 @@ class KlingApiError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R:
|
||||
def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
@@ -108,7 +128,10 @@ def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R
|
||||
if response.data and response.data.task_status
|
||||
else None
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
@@ -184,6 +207,18 @@ def validate_image_result_response(response) -> None:
|
||||
raise KlingApiError(error_msg)
|
||||
|
||||
|
||||
def validate_input_image(image: torch.Tensor) -> None:
|
||||
"""
|
||||
Validates the input image adheres to the expectations of the Kling API:
|
||||
- The image resolution should not be less than 300*300px
|
||||
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
|
||||
|
||||
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
|
||||
"""
|
||||
validate_image_dimensions(image, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
|
||||
|
||||
|
||||
def get_camera_control_input_config(
|
||||
tooltip: str, default: float = 0.0
|
||||
) -> tuple[IO, InputTypeOptions]:
|
||||
@@ -200,7 +235,9 @@ def get_camera_control_input_config(
|
||||
|
||||
|
||||
def get_video_from_response(response) -> KlingVideoResult:
|
||||
"""Returns the first video object from the Kling video generation task result."""
|
||||
"""Returns the first video object from the Kling video generation task result.
|
||||
Will raise an error if the response is not valid.
|
||||
"""
|
||||
video = response.data.task_result.videos[0]
|
||||
logging.info(
|
||||
"Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
|
||||
@@ -208,12 +245,37 @@ def get_video_from_response(response) -> KlingVideoResult:
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Kling video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response and is_valid_video_response(response):
|
||||
return str(get_video_from_response(response).url)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_images_from_response(response) -> list[KlingImageResult]:
|
||||
"""Returns the list of image objects from the Kling image generation task result.
|
||||
Will raise an error if the response is not valid.
|
||||
"""
|
||||
images = response.data.task_result.images
|
||||
logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
|
||||
return images
|
||||
|
||||
|
||||
def get_images_urls_from_response(response) -> Optional[str]:
|
||||
"""Returns the list of image urls from the Kling image generation task result.
|
||||
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
|
||||
"""
|
||||
if response and is_valid_image_response(response):
|
||||
images = get_images_from_response(response)
|
||||
image_urls = [str(image.url) for image in images]
|
||||
return "\n".join(image_urls)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def video_result_to_node_output(
|
||||
video: KlingVideoResult,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
@@ -285,6 +347,7 @@ class KlingCameraControls(KlingNodeBase):
|
||||
RETURN_TYPES = ("CAMERA_CONTROL",)
|
||||
RETURN_NAMES = ("camera_control",)
|
||||
FUNCTION = "main"
|
||||
API_NODE = False # This is just a helper node, it doesn't make an API call
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(
|
||||
@@ -391,22 +454,31 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Text to Video Node"
|
||||
|
||||
def get_response(self, task_id: str, auth_token: str) -> KlingText2VideoResponse:
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingText2VideoResponse:
|
||||
return poll_until_finished(
|
||||
auth_token,
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
@@ -419,7 +491,8 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
model_name: Optional[str] = None,
|
||||
duration: Optional[str] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
if model_name is None:
|
||||
@@ -441,14 +514,16 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = self.get_response(task_id, auth_token)
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@@ -495,7 +570,11 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
||||
@@ -507,7 +586,8 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
@@ -518,7 +598,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
camera_control=camera_control,
|
||||
auth_token=auth_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -530,7 +610,10 @@ class KlingImage2VideoNode(KlingNodeBase):
        return {
            "required": {
                "start_frame": model_field_to_node_input(
-                   IO.IMAGE, KlingImage2VideoRequest, "image"
+                   IO.IMAGE,
+                   KlingImage2VideoRequest,
+                   "image",
+                   tooltip="The reference image used to generate the video.",
                ),
                "prompt": model_field_to_node_input(
                    IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
@@ -574,22 +657,31 @@ class KlingImage2VideoNode(KlingNodeBase):
                    enum_type=KlingVideoGenDuration,
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")
    DESCRIPTION = "Kling Image to Video Node"

-   def get_response(self, task_id: str, auth_token: str) -> KlingImage2VideoResponse:
+   def get_response(
+       self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+   ) -> KlingImage2VideoResponse:
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
                method=HttpMethod.GET,
                request_model=KlingImage2VideoRequest,
                response_model=KlingImage2VideoResponse,
            ),
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_I2V,
+           node_id=node_id,
        )

    def api_call(
@@ -604,12 +696,14 @@ class KlingImage2VideoNode(KlingNodeBase):
        duration: str,
        camera_control: Optional[KlingCameraControl] = None,
        end_frame: Optional[torch.Tensor] = None,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
        validate_input_image(start_frame)

        if camera_control is not None:
-           # Camera control type for image 2 video is always simple
+           # Camera control type for image 2 video is always `simple`
            camera_control.type = KlingCameraControlType.simple

        initial_operation = SynchronousOperation(
@@ -631,18 +725,19 @@ class KlingImage2VideoNode(KlingNodeBase):
                negative_prompt=negative_prompt if negative_prompt else None,
                cfg_scale=cfg_scale,
                mode=KlingVideoGenMode(mode),
                aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
                duration=KlingVideoGenDuration(duration),
                camera_control=camera_control,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
@@ -692,7 +787,11 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
                    },
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
@@ -705,7 +804,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
        cfg_scale: float,
        aspect_ratio: str,
        camera_control: KlingCameraControl,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        return super().api_call(
            model_name=KlingVideoGenModelName.kling_v1_5,
@@ -717,7 +817,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
            prompt=prompt,
            negative_prompt=negative_prompt,
            camera_control=camera_control,
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )


@@ -785,7 +886,11 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
                    },
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
@@ -799,7 +904,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
        cfg_scale: float,
        aspect_ratio: str,
        mode: str,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
            mode
@@ -814,7 +920,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
            aspect_ratio=aspect_ratio,
            duration=duration,
            end_frame=end_frame,
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )

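Note how the camera-control and start/end-frame subclasses above no longer name `auth_token` at all: they keep `unique_id` explicit and forward everything else through `**kwargs` to the base `api_call`. A toy illustration of that forwarding pattern (class names here are hypothetical):

class BaseApiNode:
    def api_call(self, prompt: str, unique_id=None, **kwargs):
        # kwargs carries the hidden auth entries through every layer
        return f"{prompt} (node {unique_id}, auth keys: {sorted(kwargs)})"

class PresetNode(BaseApiNode):
    """Mirrors KlingCameraControlT2VNode: pin some parameters, forward the rest."""
    def api_call(self, prompt: str, unique_id=None, **kwargs):
        return super().api_call(
            prompt=f"[preset] {prompt}",
            unique_id=unique_id,  # keep the node id so progress reaches the right node
            **kwargs,             # auth_token / comfy_api_key pass through untouched
        )

print(PresetNode().api_call("hello", unique_id="42", auth_token="t"))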
@@ -844,22 +951,31 @@ class KlingVideoExtendNode(KlingNodeBase):
                IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True
            ),
        },
-       "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+       "hidden": {
+           "auth_token": "AUTH_TOKEN_COMFY_ORG",
+           "comfy_api_key": "API_KEY_COMFY_ORG",
+           "unique_id": "UNIQUE_ID",
+       },
    }

    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")
    DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."

-   def get_response(self, task_id: str, auth_token: str) -> KlingVideoExtendResponse:
+   def get_response(
+       self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+   ) -> KlingVideoExtendResponse:
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_VIDEO_EXTEND}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingVideoExtendResponse,
            ),
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
+           node_id=node_id,
        )

    def api_call(
@@ -868,7 +984,8 @@ class KlingVideoExtendNode(KlingNodeBase):
        negative_prompt: str,
        cfg_scale: float,
        video_id: str,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ) -> tuple[VideoFromFile, str, str]:
        validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
        initial_operation = SynchronousOperation(
@@ -884,14 +1001,16 @@ class KlingVideoExtendNode(KlingNodeBase):
                cfg_scale=cfg_scale,
                video_id=video_id,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
@@ -904,15 +1023,20 @@ class KlingVideoEffectsBase(KlingNodeBase):
    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")

-   def get_response(self, task_id: str, auth_token: str) -> KlingVideoEffectsResponse:
+   def get_response(
+       self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+   ) -> KlingVideoEffectsResponse:
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingVideoEffectsResponse,
            ),
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
+           node_id=node_id,
        )

    def api_call(
@@ -924,7 +1048,8 @@ class KlingVideoEffectsBase(KlingNodeBase):
        image_1: torch.Tensor,
        image_2: Optional[torch.Tensor] = None,
        mode: Optional[KlingVideoGenMode] = None,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        if dual_character:
            request_input_field = KlingDualCharacterEffectInput(
@@ -954,14 +1079,16 @@ class KlingVideoEffectsBase(KlingNodeBase):
                effect_scene=effect_scene,
                input=request_input_field,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
@@ -1002,7 +1129,11 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
                    enum_type=KlingVideoGenDuration,
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite."
@@ -1017,7 +1148,8 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
        model_name: KlingCharacterEffectModelName,
        mode: KlingVideoGenMode,
        duration: KlingVideoGenDuration,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        video, _, duration = super().api_call(
            dual_character=True,
@@ -1027,10 +1159,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
            duration=duration,
            image_1=image_left,
            image_2=image_right,
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )
        return video, duration


class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
    """Kling Single Image Video Effect Node"""

@@ -1063,7 +1197,11 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
                    enum_type=KlingVideoGenDuration,
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
@@ -1074,7 +1212,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
        effect_scene: KlingSingleImageEffectsScene,
        model_name: KlingSingleImageEffectModelName,
        duration: KlingVideoGenDuration,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        return super().api_call(
            dual_character=False,
@@ -1082,7 +1221,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
            model_name=model_name,
            duration=duration,
            image_1=image,
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )

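The effects base class branches on `dual_character` when assembling the request payload: two reference images produce a dual-character input, a single image produces a single-image input. A simplified sketch of that selection, using plain dicts where the real code builds pydantic request models:

from typing import Optional

def build_effect_input(image_1: str, image_2: Optional[str] = None,
                       dual_character: bool = False) -> dict:
    """Pick the request input shape the way the diff's shared api_call does."""
    if dual_character:
        if image_2 is None:
            raise ValueError("dual-character effects need two images")
        # per the node description: first image on the left, second on the right
        return {"kind": "dual", "left": image_1, "right": image_2}
    return {"kind": "single", "image": image_1}

print(build_effect_input("a.png", "b.png", dual_character=True))
print(build_effect_input("a.png"))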
@@ -1092,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase):
    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")

+   def validate_lip_sync_video(self, video: VideoInput):
+       """
+       Validates the input video adheres to the expectations of the Kling Lip Sync API:
+       - Video length does not exceed 10s and is not shorter than 2s
+       - Length and width dimensions should both be between 720px and 1920px
+
+       See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
+       """
+       validate_video_dimensions(video, 720, 1920)
+       validate_video_duration(video, 2, 10)
+
    def validate_text(self, text: str):
        if not text:
            raise ValueError("Text is required")
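The new `validate_lip_sync_video` docstring pins down the API's constraints: 2 to 10 seconds of video, with both dimensions between 720px and 1920px. A standalone sketch of checks with those bounds (the real helpers take a `VideoInput` object, not raw numbers):

def validate_video_dimensions(width: int, height: int, min_px: int = 720, max_px: int = 1920):
    """Illustrative version of the dimension check the diff relies on."""
    for name, value in (("width", width), ("height", height)):
        if not (min_px <= value <= max_px):
            raise ValueError(f"{name} must be between {min_px}px and {max_px}px, got {value}")

def validate_video_duration(seconds: float, min_s: float = 2, max_s: float = 10):
    if not (min_s <= seconds <= max_s):
        raise ValueError(f"duration must be between {min_s}s and {max_s}s, got {seconds}s")

validate_video_dimensions(1080, 1920)  # passes
validate_video_duration(7.5)           # passes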
@@ -1100,16 +1251,21 @@ class KlingLipSyncBase(KlingNodeBase):
                f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
            )

-   def get_response(self, task_id: str, auth_token: str) -> KlingLipSyncResponse:
+   def get_response(
+       self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+   ) -> KlingLipSyncResponse:
+       """Polls the Kling API endpoint until the task reaches a terminal state."""
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_LIP_SYNC}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingLipSyncResponse,
            ),
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_LIP_SYNC,
+           node_id=node_id,
        )

    def api_call(
@@ -1121,18 +1277,20 @@ class KlingLipSyncBase(KlingNodeBase):
        text: Optional[str] = None,
        voice_speed: Optional[float] = None,
        voice_id: Optional[str] = None,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ) -> tuple[VideoFromFile, str, str]:
        if text:
            self.validate_text(text)
        self.validate_lip_sync_video(video)

        # Upload video to Comfy API and get download URL
-       video_url = upload_video_to_comfyapi(video, auth_token)
+       video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
        logging.info("Uploaded video to Comfy API. URL: %s", video_url)

        # Upload the audio file to Comfy API and get download URL
        if audio:
-           audio_url = upload_audio_to_comfyapi(audio, auth_token)
+           audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
            logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
        else:
            audio_url = None
@@ -1156,14 +1314,16 @@ class KlingLipSyncBase(KlingNodeBase):
                    voice_id=voice_id,
                ),
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
@@ -1186,24 +1346,30 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
                    enum_type=KlingLipSyncVoiceLanguage,
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

-   DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
+   DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

    def api_call(
        self,
        video: VideoInput,
        audio: AudioInput,
        voice_language: str,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        return super().api_call(
            video=video,
            audio=audio,
            voice_language=voice_language,
            mode="audio2video",
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )


@@ -1292,10 +1458,14 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
                    IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

-   DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
+   DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

    def api_call(
        self,
@@ -1303,7 +1473,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
        text: str,
        voice: str,
        voice_speed: float,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
        return super().api_call(
@@ -1313,7 +1484,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
            voice_id=voice_id,
            voice_speed=voice_speed,
            mode="text2video",
-           auth_token=auth_token,
+           unique_id=unique_id,
+           **kwargs,
        )


@@ -1350,22 +1522,29 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
                    enum_type=KlingVirtualTryOnModelName,
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

-   DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human."
+   DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."

    def get_response(
-       self, task_id: str, auth_token: Optional[str] = None
+       self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
    ) -> KlingVirtualTryOnResponse:
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingVirtualTryOnResponse,
            ),
+           result_url_extractor=get_images_urls_from_response,
+           estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
+           node_id=node_id,
        )

    def api_call(
@@ -1373,7 +1552,8 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
        human_image: torch.Tensor,
        cloth_image: torch.Tensor,
        model_name: KlingVirtualTryOnModelName,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
@@ -1387,14 +1567,16 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
                cloth_image=tensor_to_base64_string(cloth_image),
                model_name=model_name,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_image_result_response(final_response)

        images = get_images_from_response(final_response)
@@ -1462,22 +1644,32 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
            "optional": {
                "image": (IO.IMAGE, {}),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."

    def get_response(
-       self, task_id: str, auth_token: Optional[str] = None
+       self,
+       task_id: str,
+       auth_kwargs: Optional[dict[str, str]],
+       node_id: Optional[str] = None,
    ) -> KlingImageGenerationsResponse:
        return poll_until_finished(
-           auth_token,
+           auth_kwargs,
            ApiEndpoint(
                path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingImageGenerationsResponse,
            ),
+           result_url_extractor=get_images_urls_from_response,
+           estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
+           node_id=node_id,
        )

    def api_call(
@@ -1491,7 +1683,8 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
        n: int,
        aspect_ratio: KlingImageGenAspectRatio,
        image: Optional[torch.Tensor] = None,
-       auth_token: Optional[str] = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        self.validate_prompt(prompt, negative_prompt)

@@ -1516,14 +1709,16 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
                n=n,
                aspect_ratio=aspect_ratio,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-       final_response = self.get_response(task_id, auth_token)
+       final_response = self.get_response(
+           task_id, auth_kwargs=kwargs, node_id=unique_id
+       )
        validate_image_result_response(final_response)

        images = get_images_from_response(final_response)

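The lip-sync path above also migrates the upload helpers: `upload_video_to_comfyapi` and `upload_audio_to_comfyapi` now receive the whole `auth_kwargs` dict instead of one token. A speculative stub of what that signature change implies, assuming the helper picks whichever credential is present; the body and the returned URL are placeholders, not the project's code:

def upload_video_to_comfyapi(video_bytes: bytes,
                             auth_kwargs: dict[str, str] | None = None) -> str:
    """Hypothetical stub: accept either credential from the auth dict."""
    auth_kwargs = auth_kwargs or {}
    token = auth_kwargs.get("auth_token") or auth_kwargs.get("comfy_api_key")
    if token is None:
        raise ValueError("no credential supplied in auth_kwargs")
    return "https://example.invalid/uploads/video"  # placeholder download URL

print(upload_video_to_comfyapi(b"...", {"comfy_api_key": "k"}))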
@@ -1,4 +1,6 @@
+from __future__ import annotations
 from inspect import cleandoc
+from typing import Optional
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC
 from comfy_api.input_impl.video_types import VideoFromFile
 from comfy_api_nodes.apis.luma_api import (
@@ -34,11 +36,20 @@ from comfy_api_nodes.apinode_utils import (
     process_image_response,
     validate_string,
 )
+from server import PromptServer

 import requests
 import torch
 from io import BytesIO

+LUMA_T2V_AVERAGE_DURATION = 105
+LUMA_I2V_AVERAGE_DURATION = 100
+
+def image_result_url_extractor(response: LumaGeneration):
+    return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
+
+def video_result_url_extractor(response: LumaGeneration):
+    return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None

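These two extractors are deliberately defensive: a poll can return a generation object that has no `assets` yet, so the helpers probe with `hasattr` and fall back to `None`. The function below is copied from the diff; the surrounding demo objects are illustrative only:

from types import SimpleNamespace

def video_result_url_extractor(response):
    return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None

pending = SimpleNamespace(state="dreaming")  # no assets yet
done = SimpleNamespace(state="completed",
                       assets=SimpleNamespace(video="https://example.invalid/clip.mp4"))
print(video_result_url_extractor(pending))  # None
print(video_result_url_extractor(done))     # the URL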
class LumaReferenceNode(ComfyNodeABC):
    """
@@ -201,6 +212,8 @@ class LumaImageGenerationNode(ComfyNodeABC):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -214,7 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
        image_luma_ref: LumaReferenceChain = None,
        style_image: torch.Tensor = None,
        character_image: torch.Tensor = None,
-       auth_token=None,
        unique_id: str = None,
+       **kwargs,
    ):
        validate_string(prompt, strip_whitespace=True, min_length=3)
@@ -222,19 +235,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
        api_image_ref = None
        if image_luma_ref is not None:
            api_image_ref = self._convert_luma_refs(
-               image_luma_ref, max_refs=4, auth_token=auth_token
+               image_luma_ref, max_refs=4, auth_kwargs=kwargs,
            )
        # handle style_luma_ref
        api_style_ref = None
        if style_image is not None:
            api_style_ref = self._convert_style_image(
-               style_image, weight=style_image_weight, auth_token=auth_token
+               style_image, weight=style_image_weight, auth_kwargs=kwargs,
            )
        # handle character_ref images
        character_ref = None
        if character_image is not None:
            download_urls = upload_images_to_comfyapi(
-               character_image, max_images=4, auth_token=auth_token
+               character_image, max_images=4, auth_kwargs=kwargs,
            )
            character_ref = LumaCharacterRef(
                identity0=LumaImageIdentity(images=download_urls)
@@ -255,7 +268,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
                style_ref=api_style_ref,
                character_ref=character_ref,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api: LumaGeneration = operation.execute()

@@ -269,7 +282,9 @@ class LumaImageGenerationNode(ComfyNodeABC):
            completed_statuses=[LumaState.completed],
            failed_statuses=[LumaState.failed],
            status_extractor=lambda x: x.state,
-           auth_token=auth_token,
+           result_url_extractor=image_result_url_extractor,
+           node_id=unique_id,
+           auth_kwargs=kwargs,
        )
        response_poll = operation.execute()

@@ -278,13 +293,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
        return (img,)

    def _convert_luma_refs(
-       self, luma_ref: LumaReferenceChain, max_refs: int, auth_token=None
+       self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
    ):
        luma_urls = []
        ref_count = 0
        for ref in luma_ref.refs:
            download_urls = upload_images_to_comfyapi(
-               ref.image, max_images=1, auth_token=auth_token
+               ref.image, max_images=1, auth_kwargs=auth_kwargs
            )
            luma_urls.append(download_urls[0])
            ref_count += 1
@@ -293,12 +308,12 @@ class LumaImageGenerationNode(ComfyNodeABC):
        return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

    def _convert_style_image(
-       self, style_image: torch.Tensor, weight: float, auth_token=None
+       self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
    ):
        chain = LumaReferenceChain(
            first_ref=LumaReference(image=style_image, weight=weight)
        )
-       return self._convert_luma_refs(chain, max_refs=1, auth_token=auth_token)
+       return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)


class LumaImageModifyNode(ComfyNodeABC):
@@ -350,6 +365,8 @@ class LumaImageModifyNode(ComfyNodeABC):
            "optional": {},
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -360,12 +377,12 @@ class LumaImageModifyNode(ComfyNodeABC):
        image: torch.Tensor,
        image_weight: float,
        seed,
-       auth_token=None,
        unique_id: str = None,
+       **kwargs,
    ):
        # first, upload image
        download_urls = upload_images_to_comfyapi(
-           image, max_images=1, auth_token=auth_token
+           image, max_images=1, auth_kwargs=kwargs,
        )
        image_url = download_urls[0]
        # next, make Luma call with download url provided
@@ -383,7 +400,7 @@ class LumaImageModifyNode(ComfyNodeABC):
                    url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
                ),
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api: LumaGeneration = operation.execute()

@@ -397,7 +414,9 @@ class LumaImageModifyNode(ComfyNodeABC):
            completed_statuses=[LumaState.completed],
            failed_statuses=[LumaState.failed],
            status_extractor=lambda x: x.state,
-           auth_token=auth_token,
+           result_url_extractor=image_result_url_extractor,
+           node_id=unique_id,
+           auth_kwargs=kwargs,
        )
        response_poll = operation.execute()

@@ -470,6 +489,8 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -483,7 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
        loop: bool,
        seed,
        luma_concepts: LumaConceptChain = None,
-       auth_token=None,
        unique_id: str = None,
+       **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False, min_length=3)
@@ -506,10 +527,13 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
                loop=loop,
                concepts=luma_concepts.create_api_model() if luma_concepts else None,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api: LumaGeneration = operation.execute()

+       if unique_id:
+           PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
+
        operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"/proxy/luma/generations/{response_api.id}",
@@ -520,7 +544,10 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
            completed_statuses=[LumaState.completed],
            failed_statuses=[LumaState.failed],
            status_extractor=lambda x: x.state,
-           auth_token=auth_token,
+           result_url_extractor=video_result_url_extractor,
+           node_id=unique_id,
+           estimated_duration=LUMA_T2V_AVERAGE_DURATION,
+           auth_kwargs=kwargs,
        )
        response_poll = operation.execute()

@@ -594,6 +621,8 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -608,14 +637,14 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
        first_image: torch.Tensor = None,
        last_image: torch.Tensor = None,
        luma_concepts: LumaConceptChain = None,
-       auth_token=None,
        unique_id: str = None,
+       **kwargs,
    ):
        if first_image is None and last_image is None:
            raise Exception(
                "At least one of first_image and last_image requires an input."
            )
-       keyframes = self._convert_to_keyframes(first_image, last_image, auth_token)
+       keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
        duration = duration if model != LumaVideoModel.ray_1_6 else None
        resolution = resolution if model != LumaVideoModel.ray_1_6 else None

@@ -636,10 +665,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
                keyframes=keyframes,
                concepts=luma_concepts.create_api_model() if luma_concepts else None,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api: LumaGeneration = operation.execute()

+       if unique_id:
+           PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
+
        operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"/proxy/luma/generations/{response_api.id}",
@@ -650,7 +682,10 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
            completed_statuses=[LumaState.completed],
            failed_statuses=[LumaState.failed],
            status_extractor=lambda x: x.state,
-           auth_token=auth_token,
+           result_url_extractor=video_result_url_extractor,
+           node_id=unique_id,
+           estimated_duration=LUMA_I2V_AVERAGE_DURATION,
+           auth_kwargs=kwargs,
        )
        response_poll = operation.execute()

@@ -661,7 +696,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
        self,
        first_image: torch.Tensor = None,
        last_image: torch.Tensor = None,
-       auth_token=None,
+       auth_kwargs: Optional[dict[str,str]] = None,
    ):
        if first_image is None and last_image is None:
            return None
@@ -669,12 +704,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
        frame1 = None
        if first_image is not None:
            download_urls = upload_images_to_comfyapi(
-               first_image, max_images=1, auth_token=auth_token
+               first_image, max_images=1, auth_kwargs=auth_kwargs,
            )
            frame0 = LumaImageReference(type="image", url=download_urls[0])
        if last_image is not None:
            download_urls = upload_images_to_comfyapi(
-               last_image, max_images=1, auth_token=auth_token
+               last_image, max_images=1, auth_kwargs=auth_kwargs,
            )
            frame1 = LumaImageReference(type="image", url=download_urls[0])
        return LumaKeyframes(frame0=frame0, frame1=frame1)

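`_convert_to_keyframes` uploads whichever endpoint images were supplied and wraps their URLs as `frame0`/`frame1`. A sketch of the resulting shape, using dicts in place of the `LumaImageReference`/`LumaKeyframes` models:

from typing import Optional

def convert_to_keyframes(first_url: Optional[str], last_url: Optional[str]) -> Optional[dict]:
    """Sketch of the keyframe shape implied by the diff: frame0/frame1 are
    optional image references; None when neither endpoint image was supplied."""
    if first_url is None and last_url is None:
        return None
    frame = lambda url: {"type": "image", "url": url} if url else None
    return {"frame0": frame(first_url), "frame1": frame(last_url)}

print(convert_to_keyframes("https://example.invalid/a.png", None))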
@@ -1,3 +1,7 @@
+from typing import Union
+import logging
+import torch

 from comfy.comfy_types.node_typing import IO
 from comfy_api.input_impl.video_types import VideoFromFile
 from comfy_api_nodes.apis import (
@@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import (
     upload_images_to_comfyapi,
     validate_string,
 )
+from server import PromptServer

-import torch
-import logging

+I2V_AVERAGE_DURATION = 114
+T2V_AVERAGE_DURATION = 234
+
class MinimaxTextToVideoNode:
    """
    Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
    """

+   AVERAGE_DURATION = T2V_AVERAGE_DURATION
+
    @classmethod
    def INPUT_TYPES(s):
        return {
@@ -67,6 +74,8 @@ class MinimaxTextToVideoNode:
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -84,7 +93,8 @@ class MinimaxTextToVideoNode:
        model="T2V-01",
        image: torch.Tensor=None,  # used for ImageToVideo
        subject: torch.Tensor=None,  # used for SubjectToVideo
-       auth_token=None,
+       unique_id: Union[str, None]=None,
+       **kwargs,
    ):
        '''
        Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
@@ -94,12 +104,12 @@ class MinimaxTextToVideoNode:
        # upload image, if passed in
        image_url = None
        if image is not None:
-           image_url = upload_images_to_comfyapi(image, max_images=1, auth_token=auth_token)[0]
+           image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]

        # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
        subject_reference = None
        if subject is not None:
-           subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_token=auth_token)[0]
+           subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
            subject_reference = [SubjectReferenceItem(image=subject_url)]

@@ -118,7 +128,7 @@ class MinimaxTextToVideoNode:
                subject_reference=subject_reference,
                prompt_optimizer=None,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response = video_generate_operation.execute()

@@ -137,7 +147,9 @@ class MinimaxTextToVideoNode:
            completed_statuses=["Success"],
            failed_statuses=["Fail"],
            status_extractor=lambda x: x.status.value,
-           auth_token=auth_token,
+           estimated_duration=self.AVERAGE_DURATION,
+           node_id=unique_id,
+           auth_kwargs=kwargs,
        )
        task_result = video_generate_operation.execute()

@@ -153,7 +165,7 @@ class MinimaxTextToVideoNode:
                query_params={"file_id": int(file_id)},
            ),
            request=EmptyRequest(),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        file_result = file_retrieve_operation.execute()

@@ -163,6 +175,12 @@ class MinimaxTextToVideoNode:
                f"No video was found in the response. Full response: {file_result.model_dump()}"
            )
        logging.info(f"Generated video URL: {file_url}")
+       if unique_id:
+           if hasattr(file_result.file, "backup_download_url"):
+               message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
+           else:
+               message = f"Result URL: {file_url}"
+           PromptServer.instance.send_progress_text(message, unique_id)

        video_io = download_url_to_bytesio(file_url)
        if video_io is None:
@@ -177,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
    Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
    """

+   AVERAGE_DURATION = I2V_AVERAGE_DURATION
+
    @classmethod
    def INPUT_TYPES(s):
        return {
@@ -221,6 +241,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -237,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
    Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
    """

+   AVERAGE_DURATION = T2V_AVERAGE_DURATION
+
    @classmethod
    def INPUT_TYPES(s):
        return {
@@ -279,6 +303,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

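The MiniMax hunks end by pushing a result message to the frontend when `unique_id` is present, including a backup URL when the file object carries one. A small sketch of that message assembly (delivery via `PromptServer.instance.send_progress_text` is left out):

def build_result_message(file_url: str, backup_url: str | None = None) -> str:
    """Sketch of the progress text assembled in the diff before it is sent
    to the frontend for the requesting node."""
    if backup_url:
        return f"Result URL: {file_url}\nBackup URL: {backup_url}"
    return f"Result URL: {file_url}"

print(build_result_message("https://example.invalid/v.mp4"))
print(build_result_message("https://example.invalid/v.mp4", "https://example.invalid/b.mp4"))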
@@ -93,7 +93,11 @@ class OpenAIDalle2(ComfyNodeABC):
                    },
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    RETURN_TYPES = (IO.IMAGE,)
@@ -110,7 +114,8 @@ class OpenAIDalle2(ComfyNodeABC):
        mask=None,
        n=1,
        size="1024x1024",
-       auth_token=None,
+       unique_id=None,
+       **kwargs
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "dall-e-2"
@@ -168,12 +173,12 @@ class OpenAIDalle2(ComfyNodeABC):
                else None
            ),
            content_type=content_type,
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        response = operation.execute()

-       img_tensor = validate_and_cast_response(response)
+       img_tensor = validate_and_cast_response(response, node_id=unique_id)
        return (img_tensor,)


@@ -236,7 +241,11 @@ class OpenAIDalle3(ComfyNodeABC):
                    },
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    RETURN_TYPES = (IO.IMAGE,)
@@ -252,7 +261,8 @@ class OpenAIDalle3(ComfyNodeABC):
        style="natural",
        quality="standard",
        size="1024x1024",
-       auth_token=None,
+       unique_id=None,
+       **kwargs
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "dall-e-3"
@@ -273,12 +283,12 @@ class OpenAIDalle3(ComfyNodeABC):
            style=style,
            seed=seed,
        ),
-       auth_token=auth_token,
+       auth_kwargs=kwargs,
    )

    response = operation.execute()

-   img_tensor = validate_and_cast_response(response)
+   img_tensor = validate_and_cast_response(response, node_id=unique_id)
    return (img_tensor,)


@@ -366,7 +376,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
                    },
                ),
            },
-           "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
+           "hidden": {
+               "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
+           },
        }

    RETURN_TYPES = (IO.IMAGE,)
@@ -385,7 +399,8 @@ class OpenAIGPTImage1(ComfyNodeABC):
        mask=None,
        n=1,
        size="1024x1024",
-       auth_token=None,
+       unique_id=None,
+       **kwargs
    ):
        validate_string(prompt, strip_whitespace=False)
        model = "gpt-image-1"
@@ -462,12 +477,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
            ),
            files=files if files else None,
            content_type=content_type,
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

        response = operation.execute()

-       img_tensor = validate_and_cast_response(response)
+       img_tensor = validate_and_cast_response(response, node_id=unique_id)
        return (img_tensor,)


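All three OpenAI nodes now call `validate_and_cast_response(response, node_id=unique_id)`. A hypothetical sketch of what the extra parameter implies: validation itself is unchanged, the node id only lets any progress be attributed to the requesting node. The decoding step is heavily simplified; the real function returns image tensors:

from types import SimpleNamespace

def validate_and_cast_response(response, node_id=None):
    """Illustrative stand-in, assuming a response with a .data list of images."""
    data = getattr(response, "data", None)
    if not data:
        raise ValueError("API returned no image data")
    if node_id is not None:
        print(f"[node {node_id}] received {len(data)} image(s)")  # stand-in for UI progress
    return [item["b64_json"] for item in data]  # real code decodes to tensors

resp = SimpleNamespace(data=[{"b64_json": "aGk="}])
print(validate_and_cast_response(resp, node_id="12"))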
@@ -3,6 +3,7 @@ Pika x ComfyUI API Nodes

 Pika API docs: https://pika-827374fb.mintlify.app/api-reference
 """
+from __future__ import annotations

 import io
 from typing import Optional, TypeVar
@@ -120,7 +121,10 @@ class PikaNodeBase(ComfyNodeABC):
    RETURN_TYPES = ("VIDEO",)

    def poll_for_task_status(
-       self, task_id: str, auth_token: str
+       self,
+       task_id: str,
+       auth_kwargs: Optional[dict[str, str]] = None,
+       node_id: Optional[str] = None,
    ) -> PikaGenerateResponse:
        polling_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
@@ -139,20 +143,26 @@ class PikaNodeBase(ComfyNodeABC):
            progress_extractor=lambda response: (
                response.progress if hasattr(response, "progress") else None
            ),
-           auth_token=auth_token,
+           auth_kwargs=auth_kwargs,
+           result_url_extractor=lambda response: (
+               response.url if hasattr(response, "url") else None
+           ),
+           node_id=node_id,
+           estimated_duration=60
        )
        return polling_operation.execute()

    def execute_task(
        self,
        initial_operation: SynchronousOperation[R, PikaGenerateResponse],
-       auth_token: Optional[str] = None,
+       auth_kwargs: Optional[dict[str, str]] = None,
+       node_id: Optional[str] = None,
    ) -> tuple[VideoFromFile]:
        """Executes the initial operation then polls for the task status until it is completed.

        Args:
            initial_operation: The initial operation to execute.
-           auth_token: The authentication token to use for the API call.
+           auth_kwargs: The authentication token(s) to use for the API call.

        Returns:
            A tuple containing the video file as a VIDEO output.
@@ -164,7 +174,7 @@ class PikaNodeBase(ComfyNodeABC):
            raise PikaApiError(error_msg)

        task_id = initial_response.video_id
-       final_response = self.poll_for_task_status(task_id, auth_token)
+       final_response = self.poll_for_task_status(task_id, auth_kwargs)
        if not is_valid_video_response(final_response):
            error_msg = (
                f"Pika task {task_id} succeeded but no video data found in response."
@@ -193,6 +203,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
            },
        }

@@ -206,7 +217,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
        seed: int,
        resolution: str,
        duration: int,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        # Convert image to BytesIO
        image_bytes_io = tensor_to_bytesio(image)
@@ -233,10 +245,10 @@ class PikaImageToVideoV2_2(PikaNodeBase):
            request=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)

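`execute_task` is the Pika backbone: run the initial operation, take the returned `video_id` as a task id, poll to a terminal state, then validate that the final response actually carries video data. A condensed, dependency-free sketch of that flow (the callables stand in for the synchronous and polling operations):

from typing import Callable

class TaskError(Exception):
    pass

def execute_task(submit: Callable[[], str], poll: Callable[[str], dict]) -> str:
    """Sketch of the Pika submit-then-poll flow from the diff."""
    task_id = submit()
    if not task_id:
        raise TaskError("initial request did not return a task id")
    final = poll(task_id)
    if "url" not in final:
        raise TaskError(f"task {task_id} succeeded but no video data found in response")
    return final["url"]

print(execute_task(lambda: "t-1", lambda tid: {"url": f"https://example.invalid/{tid}.mp4"}))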
class PikaTextToVideoNodeV2_2(PikaNodeBase):
@@ -259,6 +271,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -272,7 +286,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
        resolution: str,
        duration: int,
        aspect_ratio: float,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
@@ -289,11 +304,11 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
                duration=duration,
                aspectRatio=aspect_ratio,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
            content_type="application/x-www-form-urlencoded",
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaScenesV2_2(PikaNodeBase):
@@ -336,6 +351,8 @@ class PikaScenesV2_2(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -350,12 +367,13 @@ class PikaScenesV2_2(PikaNodeBase):
        duration: int,
        ingredients_mode: str,
        aspect_ratio: float,
+       unique_id: str,
        image_ingredient_1: Optional[torch.Tensor] = None,
        image_ingredient_2: Optional[torch.Tensor] = None,
        image_ingredient_3: Optional[torch.Tensor] = None,
        image_ingredient_4: Optional[torch.Tensor] = None,
        image_ingredient_5: Optional[torch.Tensor] = None,
-       auth_token: Optional[str] = None,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        # Convert all passed images to BytesIO
        all_image_bytes_io = []
@@ -396,10 +414,10 @@ class PikaScenesV2_2(PikaNodeBase):
            request=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikAdditionsNode(PikaNodeBase):
@@ -434,6 +452,8 @@ class PikAdditionsNode(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -446,7 +466,8 @@ class PikAdditionsNode(PikaNodeBase):
        prompt_text: str,
        negative_prompt: str,
        seed: int,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        # Convert video to BytesIO
        video_bytes_io = io.BytesIO()
@@ -479,10 +500,10 @@ class PikAdditionsNode(PikaNodeBase):
            request=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaSwapsNode(PikaNodeBase):
@@ -526,6 +547,8 @@ class PikaSwapsNode(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -540,7 +563,8 @@ class PikaSwapsNode(PikaNodeBase):
        prompt_text: str,
        negative_prompt: str,
        seed: int,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:
        # Convert video to BytesIO
        video_bytes_io = io.BytesIO()
@@ -583,10 +607,10 @@ class PikaSwapsNode(PikaNodeBase):
            request=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaffectsNode(PikaNodeBase):
@@ -630,6 +654,8 @@ class PikaffectsNode(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -642,7 +668,8 @@ class PikaffectsNode(PikaNodeBase):
        prompt_text: str,
        negative_prompt: str,
        seed: int,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:

        initial_operation = SynchronousOperation(
@@ -660,10 +687,10 @@ class PikaffectsNode(PikaNodeBase):
            ),
            files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaStartEndFrameNode2_2(PikaNodeBase):
@@ -681,6 +708,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -695,7 +724,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
        seed: int,
        resolution: str,
        duration: int,
-       auth_token: Optional[str] = None,
+       unique_id: str,
+       **kwargs,
    ) -> tuple[VideoFromFile]:

        pika_files = [
@@ -722,10 +752,10 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
            ),
            files=pika_files,
            content_type="multipart/form-data",
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )

-       return self.execute_task(initial_operation, auth_token)
+       return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


NODE_CLASS_MAPPINGS = {

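Several Pika nodes ship binary inputs as `multipart/form-data`: scalar fields travel in the request body while images and videos are attached as named file tuples, as in `files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}`. A sketch of that layout in requests-style terms (the helper name and payload are illustrative):

import io

def build_multipart(image_png: bytes, body: dict) -> tuple[dict, dict]:
    """Illustrative multipart layout: form fields in `data`, binary
    attachments in `files`, matching the pattern used by the Pika nodes."""
    files = {"image": ("image.png", io.BytesIO(image_png), "image/png")}
    return body, files

data, files = build_multipart(b"\x89PNG...", {"promptText": "a fox", "seed": 7})
print(sorted(data), sorted(files))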
@@ -1,5 +1,5 @@
 from inspect import cleandoc

+from typing import Optional
 from comfy_api_nodes.apis.pixverse_api import (
     PixverseTextVideoRequest,
     PixverseImageVideoRequest,
@@ -34,11 +34,22 @@ import requests
 from io import BytesIO


-def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
+AVERAGE_DURATION_T2V = 32
+AVERAGE_DURATION_I2V = 30
+AVERAGE_DURATION_T2T = 52
+
+
+def get_video_url_from_response(
+    response: PixverseGenerationStatusResponse,
+) -> Optional[str]:
+    if response.Resp is None or response.Resp.url is None:
+        return None
+    return str(response.Resp.url)
+
+
+def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
    # first, upload image to Pixverse and get image id to use in actual generation call
-   files = {
-       "image": tensor_to_bytesio(image)
-   }
+   files = {"image": tensor_to_bytesio(image)}
    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/pixverse/image/upload",
@@ -49,12 +60,14 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
        request=EmptyRequest(),
        files=files,
        content_type="multipart/form-data",
-       auth_token=auth_token,
+       auth_kwargs=auth_kwargs,
    )
    response_upload: PixverseImageUploadResponse = operation.execute()

    if response_upload.Resp is None:
-       raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
+       raise Exception(
+           f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
+       )

    return response_upload.Resp.img_id

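PixVerse generation is a two-step protocol: the image is first uploaded to `/proxy/pixverse/image/upload`, and only the returned `img_id` is embedded in the actual generation request. A compact sketch of that sequencing with stand-in callables:

def upload_then_generate(upload, generate, image_bytes: bytes, prompt: str) -> str:
    """Sketch of the two-step PixVerse flow: upload first, then reference
    the returned img_id in the generation request body."""
    img_id = upload(image_bytes)          # /image/upload -> Resp.img_id
    if img_id is None:
        raise RuntimeError("PixVerse image upload request failed")
    return generate({"img_id": img_id, "prompt": prompt})  # t2v/i2v request

task = upload_then_generate(lambda b: 101, lambda req: f"task-for-{req['img_id']}", b"png", "a fox")
print(task)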
@@ -73,7 +86,7 @@ class PixverseTemplateNode:
    def INPUT_TYPES(s):
        return {
            "required": {
-               "template": (list(pixverse_templates.keys()), ),
+               "template": (list(pixverse_templates.keys()),),
            }
        }

@@ -87,7 +100,7 @@ class PixverseTemplateNode:

class PixverseTextToVideoNode(ComfyNodeABC):
    """
-   Generates videos synchronously based on prompt and output_size.
+   Generates videos based on prompt and output_size.
    """

    RETURN_TYPES = (IO.VIDEO,)
@@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
                        "tooltip": "Prompt for the video generation",
                    },
                ),
-               "aspect_ratio": (
-                   [ratio.value for ratio in PixverseAspectRatio],
-               ),
+               "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
                "quality": (
                    [resolution.value for resolution in PixverseQuality],
                    {
@@ -143,11 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC):
                    PixverseIO.TEMPLATE,
                    {
                        "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
-                   }
-               )
+                   },
+               ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -159,9 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC):
        duration_seconds: int,
        motion_mode: str,
        seed,
-       negative_prompt: str=None,
-       pixverse_template: int=None,
-       auth_token=None,
+       negative_prompt: str = None,
+       pixverse_template: int = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
@@ -190,7 +203,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
                template_id=pixverse_template,
                seed=seed,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api = operation.execute()

@@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC):
                response_model=PixverseGenerationStatusResponse,
            ),
            completed_statuses=[PixverseStatus.successful],
-           failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
+           failed_statuses=[
+               PixverseStatus.contents_moderation,
+               PixverseStatus.failed,
+               PixverseStatus.deleted,
+           ],
            status_extractor=lambda x: x.Resp.status,
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
+           node_id=unique_id,
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_T2V,
        )
        response_poll = operation.execute()

        vid_response = requests.get(response_poll.Resp.url)

        return (VideoFromFile(BytesIO(vid_response.content)),)


class PixverseImageToVideoNode(ComfyNodeABC):
    """
-   Generates videos synchronously based on prompt and output_size.
+   Generates videos based on prompt and output_size.
    """

    RETURN_TYPES = (IO.VIDEO,)
@@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
    def INPUT_TYPES(s):
        return {
            "required": {
-               "image": (
-                   IO.IMAGE,
-               ),
+               "image": (IO.IMAGE,),
                "prompt": (
                    IO.STRING,
                    {
@@ -273,11 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
                    PixverseIO.TEMPLATE,
                    {
                        "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
-                   }
-               )
+                   },
+               ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -289,13 +310,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
        duration_seconds: int,
        motion_mode: str,
        seed,
-       negative_prompt: str=None,
-       pixverse_template: int=None,
-       auth_token=None,
+       negative_prompt: str = None,
+       pixverse_template: int = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
-       img_id = upload_image_to_pixverse(image, auth_token=auth_token)
+       img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)

        # 1080p is limited to 5 seconds duration
        # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -322,7 +343,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
                template_id=pixverse_template,
                seed=seed,
            ),
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
        )
        response_api = operation.execute()

@@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC):
                response_model=PixverseGenerationStatusResponse,
            ),
            completed_statuses=[PixverseStatus.successful],
-           failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
+           failed_statuses=[
+               PixverseStatus.contents_moderation,
+               PixverseStatus.failed,
+               PixverseStatus.deleted,
+           ],
            status_extractor=lambda x: x.Resp.status,
-           auth_token=auth_token,
+           auth_kwargs=kwargs,
+           node_id=unique_id,
+           result_url_extractor=get_video_url_from_response,
+           estimated_duration=AVERAGE_DURATION_I2V,
        )
        response_poll = operation.execute()

@@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):

class PixverseTransitionVideoNode(ComfyNodeABC):
    """
-   Generates videos synchronously based on prompt and output_size.
+   Generates videos based on prompt and output_size.
    """

    RETURN_TYPES = (IO.VIDEO,)
@@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
    def INPUT_TYPES(s):
        return {
            "required": {
-               "first_frame": (
-                   IO.IMAGE,
-               ),
-               "last_frame": (
-                   IO.IMAGE,
-               ),
+               "first_frame": (IO.IMAGE,),
+               "last_frame": (IO.IMAGE,),
                "prompt": (
                    IO.STRING,
                    {
@@ -407,6 +431,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+               "comfy_api_key": "API_KEY_COMFY_ORG",
+               "unique_id": "UNIQUE_ID",
            },
        }

@@ -419,13 +445,13 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
        duration_seconds: int,
        motion_mode: str,
        seed,
-       negative_prompt: str=None,
-       auth_token=None,
+       negative_prompt: str = None,
+       unique_id: Optional[str] = None,
+       **kwargs,
    ):
        validate_string(prompt, strip_whitespace=False)
-       first_frame_id = upload_image_to_pixverse(first_frame, auth_token=auth_token)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_token=auth_token)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -452,7 +478,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
|
||||
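Note: the recurring change in the hunks above (and in the Recraft, Stability, and Veo files below) replaces the single hidden auth_token input with a pair of hidden credentials that arrive via **kwargs and are forwarded wholesale as auth_kwargs. A minimal sketch of that forwarding, assuming a hypothetical helper name (only the hidden-input keys come from the diff):

def pick_credential(auth_kwargs: dict) -> str:
    # Sketch only: the real selection logic lives in the comfy_api_nodes helpers.
    token = auth_kwargs.get("auth_token")
    api_key = auth_kwargs.get("comfy_api_key")
    if token is None and api_key is None:
        raise ValueError("No Comfy.org credential supplied.")
    return token or api_key

# Inside a node, api_call(self, prompt, unique_id=None, **kwargs) receives the
# hidden inputs in kwargs and passes them straight through as auth_kwargs=kwargs.
print(pick_credential({"comfy_api_key": "key-123"}))  # -> key-123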
@@ -1,6 +1,8 @@
 from __future__ import annotations
 from inspect import cleandoc
+from typing import Optional
 from comfy.utils import ProgressBar
+from comfy_extras.nodes_images import SVG # Added
 from comfy.comfy_types.node_typing import IO
 from comfy_api_nodes.apis.recraft_api import (
     RecraftImageGenerationRequest,
@@ -28,9 +30,8 @@ from comfy_api_nodes.apinode_utils import (
     resize_mask_to_image,
     validate_string,
 )
-import folder_paths
-import json
-import os
+from server import PromptServer
+
 import torch
 from io import BytesIO
 from PIL import UnidentifiedImageError
@@ -43,7 +44,7 @@ def handle_recraft_file_request(
     total_pixels=4096*4096,
     timeout=1024,
     request=None,
-    auth_token=None
+    auth_kwargs: dict[str,str] = None,
 ) -> list[BytesIO]:
     """
     Handle sending common Recraft file-only request to get back file bytes.
@@ -67,7 +68,7 @@ def handle_recraft_file_request(
         request=request,
         files=files,
         content_type="multipart/form-data",
-        auth_token=auth_token,
+        auth_kwargs=auth_kwargs,
         multipart_parser=recraft_multipart_parser,
     )
     response: RecraftImageGenerationResponse = operation.execute()
@@ -162,102 +163,6 @@ class handle_recraft_image_output:
         raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
 
 
-class SVG:
-    """
-    Stores SVG representations via a list of BytesIO objects.
-    """
-    def __init__(self, data: list[BytesIO]):
-        self.data = data
-
-    def combine(self, other: SVG):
-        return SVG(self.data + other.data)
-
-    @staticmethod
-    def combine_all(svgs: list[SVG]):
-        all_svgs = []
-        for svg in svgs:
-            all_svgs.extend(svg.data)
-        return SVG(all_svgs)
-
-
-class SaveSVGNode:
-    """
-    Save SVG files on disk.
-    """
-
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-
-    RETURN_TYPES = ()
-    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
-    FUNCTION = "save_svg"
-    CATEGORY = "api node/image/Recraft"
-    OUTPUT_NODE = True
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "svg": (RecraftIO.SVG,),
-                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
-            },
-            "hidden": {
-                "prompt": "PROMPT",
-                "extra_pnginfo": "EXTRA_PNGINFO"
-            }
-        }
-
-    def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        results = list()
-
-        # Prepare metadata JSON
-        metadata_dict = {}
-        if prompt is not None:
-            metadata_dict["prompt"] = prompt
-        if extra_pnginfo is not None:
-            metadata_dict.update(extra_pnginfo)
-
-        # Convert metadata to JSON string
-        metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
-
-        for batch_number, svg_bytes in enumerate(svg.data):
-            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
-            file = f"{filename_with_batch_num}_{counter:05}_.svg"
-
-            # Read SVG content
-            svg_bytes.seek(0)
-            svg_content = svg_bytes.read().decode('utf-8')
-
-            # Inject metadata if available
-            if metadata_json:
-                # Create metadata element with CDATA section
-                metadata_element = f""" <metadata>
-<![CDATA[
-{metadata_json}
-]]>
-</metadata>
-"""
-                # Insert metadata after opening svg tag using regex
-                import re
-                svg_content = re.sub(r'(<svg[^>]*>)', r'\1\n' + metadata_element, svg_content)
-
-            # Write the modified SVG to file
-            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
-                svg_file.write(svg_content.encode('utf-8'))
-
-            results.append({
-                "filename": file,
-                "subfolder": subfolder,
-                "type": self.type
-            })
-            counter += 1
-        return { "ui": { "images": results } }
-
-
 class RecraftColorRGBNode:
     """
     Create Recraft Color by choosing specific RGB values.
@@ -485,6 +390,8 @@ class RecraftTextToImageNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
             },
         }
@@ -497,7 +404,7 @@ class RecraftTextToImageNode:
         recraft_style: RecraftStyle = None,
         negative_prompt: str = None,
         recraft_controls: RecraftControls = None,
-        auth_token=None,
+        unique_id: Optional[str] = None,
+        **kwargs,
     ):
         validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -530,12 +437,19 @@ class RecraftTextToImageNode:
                 style_id=recraft_style.style_id,
                 controls=controls_api,
             ),
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response: RecraftImageGenerationResponse = operation.execute()
         images = []
+        urls = []
         for data in response.data:
             with handle_recraft_image_output():
+                if unique_id and data.url:
+                    urls.append(data.url)
+                    urls_string = '\n'.join(urls)
+                    PromptServer.instance.send_progress_text(
+                        f"Result URL: {urls_string}", unique_id
+                    )
                 image = bytesio_to_image_tensor(
                     download_url_to_bytesio(data.url, timeout=1024)
                 )
@@ -620,6 +534,7 @@ class RecraftImageToImageNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
@@ -630,7 +545,6 @@ class RecraftImageToImageNode:
         n: int,
         strength: float,
         seed,
-        auth_token=None,
         recraft_style: RecraftStyle = None,
         negative_prompt: str = None,
         recraft_controls: RecraftControls = None,
@@ -668,7 +582,7 @@ class RecraftImageToImageNode:
                 image=image[i],
                 path="/proxy/recraft/images/imageToImage",
                 request=request,
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             with handle_recraft_image_output():
                 images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@@ -736,6 +650,7 @@ class RecraftImageInpaintingNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
@@ -746,7 +661,6 @@ class RecraftImageInpaintingNode:
         prompt: str,
         n: int,
         seed,
-        auth_token=None,
         recraft_style: RecraftStyle = None,
         negative_prompt: str = None,
         **kwargs,
@@ -781,7 +695,7 @@ class RecraftImageInpaintingNode:
                 mask=mask[i:i+1],
                 path="/proxy/recraft/images/inpaint",
                 request=request,
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             with handle_recraft_image_output():
                 images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@@ -796,8 +710,8 @@ class RecraftTextToVectorNode:
     Generates SVG synchronously based on prompt and resolution.
     """
 
-    RETURN_TYPES = (RecraftIO.SVG,)
-    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+    RETURN_TYPES = ("SVG",) # Changed
+    DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
     FUNCTION = "api_call"
     API_NODE = True
     CATEGORY = "api node/image/Recraft"
@@ -860,6 +774,8 @@ class RecraftTextToVectorNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
             },
         }
@@ -872,7 +788,7 @@ class RecraftTextToVectorNode:
         seed,
         negative_prompt: str = None,
         recraft_controls: RecraftControls = None,
-        auth_token=None,
+        unique_id: Optional[str] = None,
+        **kwargs,
     ):
         validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -903,11 +819,18 @@ class RecraftTextToVectorNode:
                 substyle=recraft_style.substyle,
                 controls=controls_api,
             ),
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response: RecraftImageGenerationResponse = operation.execute()
         svg_data = []
+        urls = []
        for data in response.data:
+            if unique_id and data.url:
+                urls.append(data.url)
+                # Print result on each iteration in case of error
+                PromptServer.instance.send_progress_text(
+                    f"Result URL: {' '.join(urls)}", unique_id
+                )
             svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
 
         return (SVG(svg_data),)
@@ -918,8 +841,8 @@ class RecraftVectorizeImageNode:
     Generates SVG synchronously from an input image.
     """
 
-    RETURN_TYPES = (RecraftIO.SVG,)
-    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+    RETURN_TYPES = ("SVG",) # Changed
+    DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
     FUNCTION = "api_call"
     API_NODE = True
     CATEGORY = "api node/image/Recraft"
@@ -934,13 +857,13 @@ class RecraftVectorizeImageNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(
         self,
         image: torch.Tensor,
-        auth_token=None,
+        **kwargs,
     ):
         svgs = []
@@ -950,7 +873,7 @@ class RecraftVectorizeImageNode:
             sub_bytes = handle_recraft_file_request(
                 image=image[i],
                 path="/proxy/recraft/images/vectorize",
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             svgs.append(SVG(sub_bytes))
             pbar.update(1)
@@ -1015,6 +938,7 @@ class RecraftReplaceBackgroundNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
@@ -1024,7 +948,6 @@ class RecraftReplaceBackgroundNode:
         prompt: str,
         n: int,
         seed,
-        auth_token=None,
         recraft_style: RecraftStyle = None,
         negative_prompt: str = None,
         **kwargs,
@@ -1054,7 +977,7 @@ class RecraftReplaceBackgroundNode:
                 image=image[i],
                 path="/proxy/recraft/images/replaceBackground",
                 request=request,
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
             pbar.update(1)
@@ -1084,13 +1007,13 @@ class RecraftRemoveBackgroundNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(
         self,
         image: torch.Tensor,
-        auth_token=None,
+        **kwargs,
     ):
         images = []
@@ -1100,7 +1023,7 @@ class RecraftRemoveBackgroundNode:
             sub_bytes = handle_recraft_file_request(
                 image=image[i],
                 path="/proxy/recraft/images/removeBackground",
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
             pbar.update(1)
@@ -1135,13 +1058,13 @@ class RecraftCrispUpscaleNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(
         self,
         image: torch.Tensor,
-        auth_token=None,
+        **kwargs,
     ):
         images = []
@@ -1151,7 +1074,7 @@ class RecraftCrispUpscaleNode:
             sub_bytes = handle_recraft_file_request(
                 image=image[i],
                 path=self.RECRAFT_PATH,
-                auth_token=auth_token,
+                auth_kwargs=kwargs,
             )
             images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
             pbar.update(1)
@@ -1193,7 +1116,6 @@ NODE_CLASS_MAPPINGS = {
     "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
     "RecraftColorRGB": RecraftColorRGBNode,
     "RecraftControls": RecraftControlsNode,
-    "SaveSVG": SaveSVGNode,
 }
 
 # A dictionary that contains the friendly/humanly readable titles for the nodes
@@ -1213,5 +1135,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
     "RecraftColorRGB": "Recraft Color RGB",
     "RecraftControls": "Recraft Controls",
-    "SaveSVG": "Save SVG",
 }
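Note: the SVG container class and SaveSVG node removed from this file reappear in comfy_extras/nodes_images.py (the last file in this diff), so the vector nodes now import it instead of defining their own. A small usage sketch of the batching helpers, assuming the module is importable as shown in the import hunk above:

from io import BytesIO
from comfy_extras.nodes_images import SVG

a = SVG([BytesIO(b"<svg></svg>")])
b = SVG([BytesIO(b"<svg><rect/></svg>")])
merged = SVG.combine_all([a, b])  # flattens both batches into one
print(len(merged.data))  # -> 2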
@@ -120,12 +120,13 @@ class StabilityStableImageUltraNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
                  negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
-                 auth_token=None):
+                 **kwargs):
         validate_string(prompt, strip_whitespace=False)
         # prepare image binary if image present
         image_binary = None
@@ -160,7 +161,7 @@ class StabilityStableImageUltraNode:
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_api = operation.execute()
 
@@ -252,12 +253,13 @@ class StabilityStableImageSD_3_5Node:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
                  negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
-                 auth_token=None):
+                 **kwargs):
         validate_string(prompt, strip_whitespace=False)
         # prepare image binary if image present
         image_binary = None
@@ -298,7 +300,7 @@ class StabilityStableImageSD_3_5Node:
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_api = operation.execute()
 
@@ -368,11 +370,12 @@ class StabilityUpscaleConservativeNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
-                 auth_token=None):
+                 **kwargs):
         validate_string(prompt, strip_whitespace=False)
         image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
 
@@ -398,7 +401,7 @@ class StabilityUpscaleConservativeNode:
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_api = operation.execute()
 
@@ -473,11 +476,12 @@ class StabilityUpscaleCreativeNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
-                 auth_token=None):
+                 **kwargs):
         validate_string(prompt, strip_whitespace=False)
         image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
 
@@ -506,7 +510,7 @@ class StabilityUpscaleCreativeNode:
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_api = operation.execute()
 
@@ -521,7 +525,7 @@ class StabilityUpscaleCreativeNode:
             completed_statuses=[StabilityPollStatus.finished],
             failed_statuses=[StabilityPollStatus.failed],
             status_extractor=lambda x: get_async_dummy_status(x),
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_poll: StabilityResultsGetResponse = operation.execute()
 
@@ -555,11 +559,12 @@ class StabilityUpscaleFastNode:
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
             },
         }
 
     def api_call(self, image: torch.Tensor,
-                 auth_token=None):
+                 **kwargs):
         image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
 
         files = {
@@ -576,7 +581,7 @@ class StabilityUpscaleFastNode:
             request=EmptyRequest(),
             files=files,
             content_type="multipart/form-data",
-            auth_token=auth_token,
+            auth_kwargs=kwargs,
         )
         response_api = operation.execute()
@@ -3,6 +3,7 @@ import logging
 import base64
 import requests
 import torch
+from typing import Optional
 
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC
 from comfy_api.input_impl.video_types import VideoFromFile
@@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import (
     tensor_to_base64_string
 )
 
+AVERAGE_DURATION_VIDEO_GEN = 32
+
 def convert_image_to_base64(image: torch.Tensor):
     if image is None:
         return None
@@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor):
     scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
     return tensor_to_base64_string(scaled_image)
 
+
+def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
+    if (
+        poll_response.response
+        and hasattr(poll_response.response, "videos")
+        and poll_response.response.videos
+        and len(poll_response.response.videos) > 0
+    ):
+        video = poll_response.response.videos[0]
+    else:
+        return None
+    if hasattr(video, "gcsUri") and video.gcsUri:
+        return str(video.gcsUri)
+    return None
+
 
 class VeoVideoGenerationNode(ComfyNodeABC):
     """
     Generates videos from text prompts using Google's Veo API.
@@ -114,6 +133,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
             },
         }
@@ -133,7 +154,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
         person_generation="ALLOW",
         seed=0,
         image=None,
-        auth_token=None,
+        unique_id: Optional[str] = None,
+        **kwargs,
     ):
         # Prepare the instances for the request
         instances = []
@@ -179,7 +201,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
                 instances=instances,
                 parameters=parameters
             ),
-            auth_token=auth_token
+            auth_kwargs=kwargs,
         )
 
         initial_response = initial_operation.execute()
@@ -213,8 +235,11 @@ class VeoVideoGenerationNode(ComfyNodeABC):
             request=Veo2GenVidPollRequest(
                 operationName=operation_name
            ),
-            auth_token=auth_token,
-            poll_interval=5.0
+            auth_kwargs=kwargs,
+            poll_interval=5.0,
+            result_url_extractor=get_video_url_from_response,
+            node_id=unique_id,
+            estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
         )
 
         # Execute the polling operation
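Note: get_video_url_from_response above is deliberately defensive, since every poll response before completion lacks videos and the extractor must return None rather than raise. A quick check of the guard clauses using stand-in objects (the real Veo2GenVidPollResponse is a typed response model):

from types import SimpleNamespace

pending = SimpleNamespace(response=None)
done = SimpleNamespace(
    response=SimpleNamespace(videos=[SimpleNamespace(gcsUri="gs://bucket/video.mp4")])
)
print(get_video_url_from_response(pending))  # -> None
print(get_video_url_from_response(done))     # -> gs://bucket/video.mp4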
0 comfy_api_nodes/util/__init__.py Normal file
100 comfy_api_nodes/util/validation_utils.py Normal file
@@ -0,0 +1,100 @@
import logging
from typing import Optional

import torch
from comfy_api.input.video_types import VideoInput


def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    if len(image.shape) == 4:
        return image.shape[1], image.shape[2]
    elif len(image.shape) == 3:
        return image.shape[0], image.shape[1]
    else:
        raise ValueError("Invalid image tensor shape.")


def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    height, width = get_image_dimensions(image)

    if min_width is not None and width < min_width:
        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Image height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")


def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    width, height = get_image_dimensions(image)
    aspect_ratio = width / height

    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )


def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        logging.error("Error getting dimensions of video: %s", e)
        return

    if min_width is not None and width < min_width:
        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Video height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")


def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    try:
        duration = video.get_duration()
    except Exception as e:
        logging.error("Error getting duration of video: %s", e)
        return

    epsilon = 0.0001
    if min_duration is not None and min_duration - epsilon > duration:
        raise ValueError(
            f"Video duration must be at least {min_duration}s, got {duration}s"
        )
    if max_duration is not None and duration > max_duration + epsilon:
        raise ValueError(
            f"Video duration must be at most {max_duration}s, got {duration}s"
        )
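Note: a quick usage sketch of the new validators, assuming the module imports under the path shown in the file header. One thing worth double-checking upstream: get_image_dimensions returns (height, width), and validate_image_dimensions unpacks it that way, but validate_image_aspect_ratio unpacks width, height from the same call, so its ratio is effectively height/width.

import torch
from comfy_api_nodes.util.validation_utils import validate_image_dimensions

image = torch.zeros([1, 512, 768, 3])  # B, H, W, C layout used by ComfyUI
validate_image_dimensions(image, min_width=256, max_width=4096)  # passes silently
try:
    validate_image_dimensions(image, min_height=1024)
except ValueError as e:
    print(e)  # Image height must be at least 1024px, got 512px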
49 comfy_extras/nodes_ace.py Normal file
@@ -0,0 +1,49 @@
import torch
import comfy.model_management
import node_helpers

class TextEncodeAceStepAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
        }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip, tags, lyrics, lyrics_strength):
        tokens = clip.tokenize(tags, lyrics=lyrics)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
        return (conditioning, )


class EmptyAceStepLatentAudio:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/audio"

    def generate(self, seconds, batch_size):
        length = int(seconds * 44100 / 512 / 8)
        latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
        return ({"samples": latent, "type": "audio"}, )


NODE_CLASS_MAPPINGS = {
    "TextEncodeAceStepAudio": TextEncodeAceStepAudio,
    "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}
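Note: the latent length in EmptyAceStepLatentAudio comes from 44.1 kHz audio passed through a 512-sample hop and a further 8x temporal compression, so each latent frame covers roughly 93 ms of audio. Worked through for the defaults:

def ace_latent_length(seconds: float) -> int:
    # Same arithmetic as EmptyAceStepLatentAudio.generate above.
    return int(seconds * 44100 / 512 / 8)

print(ace_latent_length(120.0))  # -> 1291 latent frames for the 120 s default
print(ace_latent_length(1.0))    # -> 10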
76 comfy_extras/nodes_apg.py Normal file
@@ -0,0 +1,76 @@
import torch

def project(v0, v1):
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel, v0_orthogonal

class APG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
                "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "sampling/custom_sampling"

    def patch(self, model, eta, norm_threshold, momentum):
        running_avg = 0
        prev_sigma = None

        def pre_cfg_function(args):
            nonlocal running_avg, prev_sigma

            if len(args["conds_out"]) == 1: return args["conds_out"]

            cond = args["conds_out"][0]
            uncond = args["conds_out"][1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            if prev_sigma is not None and sigma > prev_sigma:
                running_avg = 0
            prev_sigma = sigma

            guidance = cond - uncond

            if momentum != 0:
                if not torch.is_tensor(running_avg):
                    running_avg = guidance
                else:
                    running_avg = momentum * running_avg + guidance
                guidance = running_avg

            if norm_threshold > 0:
                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale = torch.minimum(
                    torch.ones_like(guidance_norm),
                    norm_threshold / guidance_norm
                )
                guidance = guidance * scale

            guidance_parallel, guidance_orthogonal = project(guidance, cond)
            modified_guidance = guidance_orthogonal + eta * guidance_parallel

            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale

            return [modified_cond, uncond] + args["conds_out"][2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m,)

NODE_CLASS_MAPPINGS = {
    "APG": APG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "APG": "Adaptive Projected Guidance",
}
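Note: the heart of APG is the project() split above: guidance is decomposed into components parallel and orthogonal to the normalized conditional prediction, and eta rescales only the parallel part. A standalone sanity check of that decomposition, with arbitrary stand-in shapes:

import torch

cond = torch.randn(2, 4, 8, 8)      # stand-in for the conditional prediction
guidance = torch.randn(2, 4, 8, 8)  # stand-in for cond - uncond
par, orth = project(guidance, cond)
assert torch.allclose(par + orth, guidance, atol=1e-5)        # split is exact
assert (par * orth).sum(dim=[-1, -2, -3]).abs().max() < 1e-3  # parts are orthogonal
modified = orth + 0.5 * par  # eta = 0.5 shrinks only the cond-aligned component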
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import av
 import torchaudio
 import torch
 import comfy.model_management
@@ -7,7 +8,6 @@ import folder_paths
 import os
 import io
 import json
-import struct
 import random
 import hashlib
 import node_helpers
@@ -90,60 +90,118 @@ class VAEDecodeAudio:
         return ({"waveform": audio, "sample_rate": 44100}, )
 
 
-def create_vorbis_comment_block(comment_dict, last_block):
-    vendor_string = b'ComfyUI'
-    vendor_length = len(vendor_string)
+def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
 
-    comments = []
-    for key, value in comment_dict.items():
-        comment = f"{key}={value}".encode('utf-8')
-        comments.append(struct.pack('<I', len(comment)) + comment)
+    filename_prefix += self.prefix_append
+    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+    results: list[FileLocator] = []
 
-    user_comment_list_length = len(comments)
-    user_comments = b''.join(comments)
+    # Prepare metadata dictionary
+    metadata = {}
+    if not args.disable_metadata:
+        if prompt is not None:
+            metadata["prompt"] = json.dumps(prompt)
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
 
-    comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
-    if last_block:
-        id = b'\x84'
-    else:
-        id = b'\x04'
-    comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
+    # Opus supported sample rates
+    OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
 
-    return comment_block
+    for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
+        filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+        file = f"{filename_with_batch_num}_{counter:05}_.{format}"
+        output_path = os.path.join(full_output_folder, file)
 
-def insert_or_replace_vorbis_comment(flac_io, comment_dict):
-    if len(comment_dict) == 0:
-        return flac_io
+        # Use original sample rate initially
+        sample_rate = audio["sample_rate"]
 
-    flac_io.seek(4)
+        # Handle Opus sample rate requirements
+        if format == "opus":
+            if sample_rate > 48000:
+                sample_rate = 48000
+            elif sample_rate not in OPUS_RATES:
+                # Find the next highest supported rate
+                for rate in sorted(OPUS_RATES):
+                    if rate > sample_rate:
+                        sample_rate = rate
+                        break
+                if sample_rate not in OPUS_RATES:  # Fallback if still not supported
+                    sample_rate = 48000
 
-    blocks = []
-    last_block = False
+        # Resample if necessary
+        if sample_rate != audio["sample_rate"]:
+            waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
 
-    while not last_block:
-        header = flac_io.read(4)
-        last_block = (header[0] & 0x80) != 0
-        block_type = header[0] & 0x7F
-        block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
-        block_data = flac_io.read(block_length)
+        # Create in-memory WAV buffer
+        wav_buffer = io.BytesIO()
+        torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
+        wav_buffer.seek(0)  # Rewind for reading
 
-        if block_type == 4 or block_type == 1:
-            pass
-        else:
-            header = bytes([(header[0] & (~0x80))]) + header[1:]
-            blocks.append(header + block_data)
+        # Use PyAV to convert and add metadata
+        input_container = av.open(wav_buffer)
 
-    blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
+        # Create output with specified format
+        output_buffer = io.BytesIO()
+        output_container = av.open(output_buffer, mode='w', format=format)
 
-    new_flac_io = io.BytesIO()
-    new_flac_io.write(b'fLaC')
-    for block in blocks:
-        new_flac_io.write(block)
+        # Set metadata on the container
+        for key, value in metadata.items():
+            output_container.metadata[key] = value
 
-    new_flac_io.write(flac_io.read())
-    return new_flac_io
+        # Set up the output stream with appropriate properties
+        input_container.streams.audio[0]
+        if format == "opus":
+            out_stream = output_container.add_stream("libopus", rate=sample_rate)
+            if quality == "64k":
+                out_stream.bit_rate = 64000
+            elif quality == "96k":
+                out_stream.bit_rate = 96000
+            elif quality == "128k":
+                out_stream.bit_rate = 128000
+            elif quality == "192k":
+                out_stream.bit_rate = 192000
+            elif quality == "320k":
+                out_stream.bit_rate = 320000
+        elif format == "mp3":
+            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+            if quality == "V0":
+                #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
+                out_stream.codec_context.qscale = 1
+            elif quality == "128k":
+                out_stream.bit_rate = 128000
+            elif quality == "320k":
+                out_stream.bit_rate = 320000
+        else:  # format == "flac":
+            out_stream = output_container.add_stream("flac", rate=sample_rate)
+
+        # Copy frames from input to output
+        for frame in input_container.decode(audio=0):
+            frame.pts = None  # Let PyAV handle timestamps
+            output_container.mux(out_stream.encode(frame))
+
+        # Flush encoder
+        output_container.mux(out_stream.encode(None))
+
+        # Close containers
+        output_container.close()
+        input_container.close()
+
+        # Write the output to file
+        output_buffer.seek(0)
+        with open(output_path, 'wb') as f:
+            f.write(output_buffer.getbuffer())
+
+        results.append({
+            "filename": file,
+            "subfolder": subfolder,
+            "type": self.type
+        })
+        counter += 1
+
+    return { "ui": { "audio": results } }
 
 class SaveAudio:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
@@ -153,50 +211,70 @@ class SaveAudio:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "audio": ("AUDIO", ),
-                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
+                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+                              },
                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                 }
 
     RETURN_TYPES = ()
-    FUNCTION = "save_audio"
+    FUNCTION = "save_flac"
 
     OUTPUT_NODE = True
 
     CATEGORY = "audio"
 
-    def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        results: list[FileLocator] = []
-
-        metadata = {}
-        if not args.disable_metadata:
-            if prompt is not None:
-                metadata["prompt"] = json.dumps(prompt)
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata[x] = json.dumps(extra_pnginfo[x])
-
-        for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
-            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
-            file = f"{filename_with_batch_num}_{counter:05}_.flac"
-
-            buff = io.BytesIO()
-            torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
-
-            buff = insert_or_replace_vorbis_comment(buff, metadata)
-
-            with open(os.path.join(full_output_folder, file), 'wb') as f:
-                f.write(buff.getbuffer())
-
-            results.append({
-                "filename": file,
-                "subfolder": subfolder,
-                "type": self.type
-            })
-            counter += 1
-
-        return { "ui": { "audio": results } }
+    def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
+        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
+
+class SaveAudioMP3:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type = "output"
+        self.prefix_append = ""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ),
+                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+                              "quality": (["V0", "128k", "320k"], {"default": "V0"}),
+                              },
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_mp3"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "audio"
+
+    def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
+        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
+
+class SaveAudioOpus:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type = "output"
+        self.prefix_append = ""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ),
+                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+                              "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
+                              },
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_opus"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "audio"
+
+    def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
+        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
 
 class PreviewAudio(SaveAudio):
     def __init__(self):
@@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
     "VAEEncodeAudio": VAEEncodeAudio,
     "VAEDecodeAudio": VAEDecodeAudio,
     "SaveAudio": SaveAudio,
+    "SaveAudioMP3": SaveAudioMP3,
+    "SaveAudioOpus": SaveAudioOpus,
     "LoadAudio": LoadAudio,
     "PreviewAudio": PreviewAudio,
     "ConditioningStableAudio": ConditioningStableAudio,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "EmptyLatentAudio": "Empty Latent Audio",
+    "VAEEncodeAudio": "VAE Encode Audio",
+    "VAEDecodeAudio": "VAE Decode Audio",
+    "PreviewAudio": "Preview Audio",
+    "LoadAudio": "Load Audio",
+    "SaveAudio": "Save Audio (FLAC)",
+    "SaveAudioMP3": "Save Audio (MP3)",
+    "SaveAudioOpus": "Save Audio (Opus)",
+}
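Note: libopus only accepts the five rates listed in OPUS_RATES, so save_audio bumps the incoming rate to the next supported one before resampling; the common 44.1 kHz case lands on 48 kHz. The selection logic in isolation:

OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

def opus_target_rate(sample_rate: int) -> int:
    # Mirrors the opus branch of save_audio above.
    if sample_rate > 48000:
        return 48000
    if sample_rate in OPUS_RATES:
        return sample_rate
    for rate in sorted(OPUS_RATES):
        if rate > sample_rate:
            return rate
    return 48000

print(opus_target_rate(44100))  # -> 48000
print(opus_target_rate(22050))  # -> 24000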
218 comfy_extras/nodes_camera_trajectory.py Normal file
@@ -0,0 +1,218 @@
import nodes
import torch
import numpy as np
from einops import rearrange
import comfy.model_management


MAX_RESOLUTION = nodes.MAX_RESOLUTION

CAMERA_DICT = {
    "base_T_norm": 1.5,
    "base_angle": np.pi/3,
    "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
    "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
    "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
    "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
    "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
    "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
    "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
    "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
    "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
}


def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):

    def get_relative_pose(cam_params):
        """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
        """
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        cam_to_origin = 0
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed

    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    intrinsic = np.asarray([[cam_param.fx * width,
                             cam_param.fy * height,
                             cam_param.cx * width,
                             cam_param.cy * height]
                            for cam_param in cam_params], dtype=np.float32)

    K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
    c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding[None]
    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
    return plucker_embedding


class Camera(object):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        c2w_mat = np.array(entry[7:]).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)


def ray_condition(K, c2w, H, W, device):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = torch.meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
        indexing='ij'
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1

    zs = torch.ones_like(i)  # [B, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW
    # c2w @ directions
    rays_dxo = torch.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker


def get_camera_motion(angle, T, speed, n=81):
    def compute_R_form_rad_angle(angles):
        theta_x, theta_y, theta_z = angles
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(theta_x), -np.sin(theta_x)],
                       [0, np.sin(theta_x), np.cos(theta_x)]])

        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
                       [0, 1, 0],
                       [-np.sin(theta_y), 0, np.cos(theta_y)]])

        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
                       [np.sin(theta_z), np.cos(theta_z), 0],
                       [0, 0, 1]])

        R = np.dot(Rz, np.dot(Ry, Rx))
        return R
    RT = []
    for i in range(n):
        _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
        R = compute_R_form_rad_angle(_angle)
        _T = (i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
        _RT = np.concatenate([R, _T], axis=1)
        RT.append(_RT)
    RT = np.stack(RT)
    return RT


class WanCameraEmbedding:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "camera_pose": (["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"], {"default": "Static"}),
                "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
            },
            "optional": {
                "speed": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.1}),
                "fx": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.000000001}),
                "fy": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.000000001}),
                "cx": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}),
                "cy": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("WAN_CAMERA_EMBEDDING", "INT", "INT", "INT")
    RETURN_NAMES = ("camera_embedding", "width", "height", "length")
    FUNCTION = "run"
    CATEGORY = "camera"

    def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
        """
        Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
        """
        motion_list = [camera_pose]
        speed = speed
        angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
        T = np.array(CAMERA_DICT[motion_list[0]]["T"])
        RT = get_camera_motion(angle, T, speed, length)

        trajs = []
        for cp in RT.tolist():
            traj = [fx, fy, cx, cy, 0, 0]
            traj.extend(cp[0])
            traj.extend(cp[1])
            traj.extend(cp[2])
            traj.extend([0, 0, 0, 1])
            trajs.append(traj)

        cam_params = np.array([[float(x) for x in pose] for pose in trajs])
        cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
        control_camera_video = process_pose_params(cam_params, width=width, height=height)
        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())

        control_camera_video = torch.concat(
            [
                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
                control_camera_video[:, :, 1:]
            ], dim=2
        ).transpose(1, 2)

        # Reshape, transpose, and view into desired shape
        b, f, c, h, w = control_camera_video.shape
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)

        return (control_camera_video, width, height, length)


NODE_CLASS_MAPPINGS = {
    "WanCameraEmbedding": WanCameraEmbedding,
}
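Note: get_camera_motion linearly ramps the rotation angles and translation from zero up to speed times the base magnitudes over n frames, producing an (n, 3, 4) stack of [R|T] extrinsics; WanCameraEmbedding then prepends the intrinsics and appends a homogeneous row to build the flat pose entries that Camera parses. For example:

import numpy as np

angle = np.array(CAMERA_DICT["Pan Right"]["angle"])  # zero rotation
T = np.array(CAMERA_DICT["Pan Right"]["T"])          # unit +x translation
RT = get_camera_motion(angle, T, speed=1.0, n=81)
print(RT.shape)      # (81, 3, 4)
print(RT[0][:, 3])   # frame 0: translation is zero, R is identity
print(RT[-1][0, 3])  # ~1.481 = (80/81) * base_T_norm of 1.5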
@@ -31,6 +31,7 @@ class T5TokenizerOptions:
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "_for_testing/conditioning"
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "set_options"
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
|
||||
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
elif guidance_type == "v2 (replace)":
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
elif guidance_type == "custom":
|
||||
cond = {"ref_latent": concat_latent_image}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@@ -71,6 +75,24 @@ class ImageFromBatch:
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
|
||||
|
||||
class ImageAddNoise:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def repeat(self, image, seed, strength):
|
||||
generator = torch.manual_seed(seed)
|
||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||
return (s,)
|
||||
|
||||
class SaveAnimatedWEBP:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
@@ -190,10 +212,110 @@ class SaveAnimatedPNG:

        return { "ui": { "images": results, "animated": (True,)} }

class SVG:
    """
    Stores SVG representations via a list of BytesIO objects.
    """
    def __init__(self, data: list[BytesIO]):
        self.data = data

    def combine(self, other: 'SVG') -> 'SVG':
        return SVG(self.data + other.data)

    @staticmethod
    def combine_all(svgs: list['SVG']) -> 'SVG':
        all_svgs_list: list[BytesIO] = []
        for svg_item in svgs:
            all_svgs_list.extend(svg_item.data)
        return SVG(all_svgs_list)

class SaveSVGNode:
    """
    Save SVG files on disk.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    RETURN_TYPES = ()
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "save_svg"
    CATEGORY = "image/save"  # Changed
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "svg": ("SVG",),  # Changed
                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            }
        }

    def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        results = list()

        # Prepare metadata JSON
        metadata_dict = {}
        if prompt is not None:
            metadata_dict["prompt"] = prompt
        if extra_pnginfo is not None:
            metadata_dict.update(extra_pnginfo)

        # Convert metadata to JSON string
        metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None

        for batch_number, svg_bytes in enumerate(svg.data):
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.svg"

            # Read SVG content
            svg_bytes.seek(0)
            svg_content = svg_bytes.read().decode('utf-8')

            # Inject metadata if available
            if metadata_json:
                # Create metadata element with CDATA section
                metadata_element = f"""  <metadata>
    <![CDATA[
{metadata_json}
    ]]>
  </metadata>
"""
                # Insert metadata after opening svg tag using regex with a replacement function
                def replacement(match):
                    # match.group(1) contains the captured <svg> tag
                    return match.group(1) + '\n' + metadata_element

                # Apply the substitution
                svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)

            # Write the modified SVG to file
            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
                svg_file.write(svg_content.encode('utf-8'))

            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1
        return { "ui": { "images": results } }
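Passing a function to `re.sub` here is deliberate: used as a replacement *string*, the serialized metadata could contain sequences like `\1` that the regex engine treats as backreferences. A small sketch of the failure mode and the fix:

```python
import re

svg = '<svg xmlns="http://www.w3.org/2000/svg"><rect/></svg>'
payload = r'{"note": "see group \1"}'

# As a replacement string, "\1" inside the payload is expanded to the
# matched <svg> tag, corrupting the injected JSON:
broken = re.sub(r'(<svg[^>]*>)', r'\1' + payload, svg)

# As a replacement function, the payload is inserted verbatim,
# with no escape processing:
fixed = re.sub(r'(<svg[^>]*>)', lambda m: m.group(1) + payload, svg)

print(broken)  # the <svg ...> tag appears twice, once inside the JSON
print(fixed)   # the JSON still contains the literal "\1"
```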
NODE_CLASS_MAPPINGS = {
    "ImageCrop": ImageCrop,
    "RepeatImageBatch": RepeatImageBatch,
    "ImageFromBatch": ImageFromBatch,
    "ImageAddNoise": ImageAddNoise,
    "SaveAnimatedWEBP": SaveAnimatedWEBP,
    "SaveAnimatedPNG": SaveAnimatedPNG,
    "SaveSVGNode": SaveSVGNode,
}

@@ -2,6 +2,10 @@ import nodes
import folder_paths
import os

from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile


def normalize_path(path):
    return path.replace('\\', '/')

@@ -21,8 +25,8 @@ class Load3D():
                "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                }}

-   RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
-   RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
+   RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
+   RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")

    FUNCTION = "process"
    EXPERIMENTAL = True
@@ -41,7 +45,14 @@ class Load3D():
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
        lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)

-       return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
+       video = None
+
+       if image['recording'] != "":
+           recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
+
+           video = VideoFromFile(recording_video_path)
+
+       return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video

class Load3DAnimation():
    @classmethod
@@ -59,8 +70,8 @@ class Load3DAnimation():
                "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
                }}

-   RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
-   RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
+   RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
+   RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")

    FUNCTION = "process"
    EXPERIMENTAL = True
@@ -77,7 +88,14 @@ class Load3DAnimation():
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)

-       return output_image, output_mask, model_file, normal_image, image['camera_info']
+       video = None
+
+       if image['recording'] != "":
+           recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
+
+           video = VideoFromFile(recording_video_path)
+
+       return output_image, output_mask, model_file, normal_image, image['camera_info'], video

class Preview3D():
    @classmethod

323
comfy_extras/nodes_string.py
Normal file
@@ -0,0 +1,323 @@
import re

from comfy.comfy_types.node_typing import IO

class StringConcatenate():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string_a": (IO.STRING, {"multiline": True}),
                "string_b": (IO.STRING, {"multiline": True}),
                "delimiter": (IO.STRING, {"multiline": False, "default": ""})
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string_a, string_b, delimiter, **kwargs):
        return delimiter.join((string_a, string_b)),

class StringSubstring():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "start": (IO.INT, {}),
                "end": (IO.INT, {}),
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, start, end, **kwargs):
        return string[start:end],

class StringLength():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True})
            }
        }

    RETURN_TYPES = (IO.INT,)
    RETURN_NAMES = ("length",)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, **kwargs):
        length = len(string)

        return length,

class CaseConverter():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, mode, **kwargs):
        if mode == "UPPERCASE":
            result = string.upper()
        elif mode == "lowercase":
            result = string.lower()
        elif mode == "Capitalize":
            result = string.capitalize()
        elif mode == "Title Case":
            result = string.title()
        else:
            result = string

        return result,


class StringTrim():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, mode, **kwargs):
        if mode == "Both":
            result = string.strip()
        elif mode == "Left":
            result = string.lstrip()
        elif mode == "Right":
            result = string.rstrip()
        else:
            result = string

        return result,

class StringReplace():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "find": (IO.STRING, {"multiline": True}),
                "replace": (IO.STRING, {"multiline": True})
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, find, replace, **kwargs):
        result = string.replace(find, replace)
        return result,


class StringContains():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "substring": (IO.STRING, {"multiline": True}),
                "case_sensitive": (IO.BOOLEAN, {"default": True})
            }
        }

    RETURN_TYPES = (IO.BOOLEAN,)
    RETURN_NAMES = ("contains",)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, substring, case_sensitive, **kwargs):
        if case_sensitive:
            contains = substring in string
        else:
            contains = substring.lower() in string.lower()

        return contains,


class StringCompare():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string_a": (IO.STRING, {"multiline": True}),
                "string_b": (IO.STRING, {"multiline": True}),
                "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
                "case_sensitive": (IO.BOOLEAN, {"default": True})
            }
        }

    RETURN_TYPES = (IO.BOOLEAN,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
        if case_sensitive:
            a = string_a
            b = string_b
        else:
            a = string_a.lower()
            b = string_b.lower()

        if mode == "Equal":
            return a == b,
        elif mode == "Starts With":
            return a.startswith(b),
        elif mode == "Ends With":
            return a.endswith(b),

class RegexMatch():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "regex_pattern": (IO.STRING, {"multiline": True}),
                "case_insensitive": (IO.BOOLEAN, {"default": True}),
                "multiline": (IO.BOOLEAN, {"default": False}),
                "dotall": (IO.BOOLEAN, {"default": False})
            }
        }

    RETURN_TYPES = (IO.BOOLEAN,)
    RETURN_NAMES = ("matches",)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
        flags = 0

        if case_insensitive:
            flags |= re.IGNORECASE
        if multiline:
            flags |= re.MULTILINE
        if dotall:
            flags |= re.DOTALL

        try:
            match = re.search(regex_pattern, string, flags)
            result = match is not None

        except re.error:
            result = False

        return result,


class RegexExtract():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "string": (IO.STRING, {"multiline": True}),
                "regex_pattern": (IO.STRING, {"multiline": True}),
                "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
                "case_insensitive": (IO.BOOLEAN, {"default": True}),
                "multiline": (IO.BOOLEAN, {"default": False}),
                "dotall": (IO.BOOLEAN, {"default": False}),
                "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
            }
        }

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
        join_delimiter = "\n"

        flags = 0
        if case_insensitive:
            flags |= re.IGNORECASE
        if multiline:
            flags |= re.MULTILINE
        if dotall:
            flags |= re.DOTALL

        try:
            if mode == "First Match":
                match = re.search(regex_pattern, string, flags)
                if match:
                    result = match.group(0)
                else:
                    result = ""

            elif mode == "All Matches":
                matches = re.findall(regex_pattern, string, flags)
                if matches:
                    if isinstance(matches[0], tuple):
                        result = join_delimiter.join([m[0] for m in matches])
                    else:
                        result = join_delimiter.join(matches)
                else:
                    result = ""

            elif mode == "First Group":
                match = re.search(regex_pattern, string, flags)
                if match and len(match.groups()) >= group_index:
                    result = match.group(group_index)
                else:
                    result = ""

            elif mode == "All Groups":
                matches = re.finditer(regex_pattern, string, flags)
                results = []
                for match in matches:
                    if match.groups() and len(match.groups()) >= group_index:
                        results.append(match.group(group_index))
                result = join_delimiter.join(results)
            else:
                result = ""

        except re.error:
            result = ""

        return result,

NODE_CLASS_MAPPINGS = {
    "StringConcatenate": StringConcatenate,
    "StringSubstring": StringSubstring,
    "StringLength": StringLength,
    "CaseConverter": CaseConverter,
    "StringTrim": StringTrim,
    "StringReplace": StringReplace,
    "StringContains": StringContains,
    "StringCompare": StringCompare,
    "RegexMatch": RegexMatch,
    "RegexExtract": RegexExtract
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "StringConcatenate": "Concatenate",
    "StringSubstring": "Substring",
    "StringLength": "Length",
    "CaseConverter": "Case Converter",
    "StringTrim": "Trim",
    "StringReplace": "Replace",
    "StringContains": "Contains",
    "StringCompare": "Compare",
    "RegexMatch": "Regex Match",
    "RegexExtract": "Regex Extract"
}
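Because each node's `execute` is an ordinary method, the utilities can be sanity-checked directly in a REPL, outside ComfyUI's graph machinery. A quick sketch (importing from the new module):

```python
from comfy_extras.nodes_string import StringConcatenate, StringCompare, RegexExtract

print(StringConcatenate().execute("foo", "bar", ", "))
# ('foo, bar',)

print(StringCompare().execute("Hello", "he", "Starts With", False))
# (True,)

print(RegexExtract().execute("a=1 b=2", r"(\w)=(\d)", "All Groups",
                             False, False, False, 2))
# ('1\n2',)
```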
@@ -297,6 +297,52 @@ class TrimVideoLatent:
        samples_out["samples"] = s1[:, :, trim_amount:]
        return (samples_out,)

class WanCameraImageToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             },
                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "start_image": ("IMAGE", ),
                             "camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)

        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            concat_latent_image = vae.encode(start_image[:, :, :, :3])
            concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]

        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})

        if camera_conditions is not None:
            positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
            negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {}
        out_latent["samples"] = latent
        return (positive, negative, out_latent)

NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
@@ -305,4 +351,5 @@ NODE_CLASS_MAPPINGS = {
    "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
    "WanVaceToVideo": WanVaceToVideo,
    "TrimVideoLatent": TrimVideoLatent,
+   "WanCameraImageToVideo": WanCameraImageToVideo,
}

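The latent allocation follows Wan's 8x spatial and 4x temporal compression: the default 81 frames at 832x480 become ((81 - 1) // 4) + 1 = 21 latent frames of 60x104. A quick check of that arithmetic:

```python
batch_size, length, height, width = 1, 81, 480, 832
latent_shape = [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
print(latent_shape)  # [1, 16, 21, 60, 104]
```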
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.31"
+__version__ = "0.3.34"

@@ -146,6 +146,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
                input_data_all[x] = [unique_id]
            if h[x] == "AUTH_TOKEN_COMFY_ORG":
                input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
+           if h[x] == "API_KEY_COMFY_ORG":
+               input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
    return input_data_all, missing_keys

map_node_over_list = None #Don't hook this please

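Hidden inputs declared under `INPUT_TYPES()["hidden"]` are resolved here from the request's `extra_data`, so an API node can receive the key without it appearing in the graph. A hedged sketch of what a consuming node might look like (the class name and field names are illustrative, not from this diff):

```python
class HypotheticalApiNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"prompt": ("STRING", {"multiline": True})},
            # get_input_data fills this from extra_data["api_key_comfy_org"]
            "hidden": {"api_key": "API_KEY_COMFY_ORG"},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"

    def run(self, prompt, api_key=None):
        # api_key stays None unless the /prompt request carried it in extra_data
        return (f"would call remote API with key={'set' if api_key else 'unset'}",)
```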
28
fix_torch.py
@@ -1,28 +0,0 @@
import importlib.util
import shutil
import os
import ctypes
import logging


def fix_pytorch_libomp():
    """
    Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
    """
    torch_spec = importlib.util.find_spec("torch")
    for folder in torch_spec.submodule_search_locations:
        lib_folder = os.path.join(folder, "lib")
        test_file = os.path.join(lib_folder, "fbgemm.dll")
        dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
        if os.path.exists(dest):
            break

        with open(test_file, "rb") as f:
            contents = f.read()
        if b"libomp140.x86_64.dll" not in contents:
            break
        try:
            ctypes.cdll.LoadLibrary(test_file)
        except FileNotFoundError:
            logging.warning("Detected pytorch version with libomp issue, patching.")
            shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
9
main.py
@@ -125,13 +125,6 @@ if __name__ == "__main__":

    import cuda_malloc

-   if args.windows_standalone_build:
-       try:
-           from fix_torch import fix_pytorch_libomp
-           fix_pytorch_libomp()
-       except:
-           pass

import comfy.utils

import execution
@@ -270,7 +263,7 @@ def start_comfyui(asyncio_loop=None):
    q = execution.PromptQueue(prompt_server)

    hook_breaker_ac10a0.save_functions()
-   nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
+   nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
    hook_breaker_ac10a0.restore_functions()

    cuda_malloc_warning()

44
nodes.py
@@ -246,6 +246,9 @@ class ConditioningZeroOut:
            pooled_output = d.get("pooled_output", None)
            if pooled_output is not None:
                d["pooled_output"] = torch.zeros_like(pooled_output)
            conditioning_lyrics = d.get("conditioning_lyrics", None)
            if conditioning_lyrics is not None:
                d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics)
            n = [torch.zeros_like(t[0]), d]
            c.append(n)
        return (c, )
@@ -917,7 +920,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
+                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
                            },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -1937,7 +1940,7 @@ class ImagePadForOutpaint:

        mask[top:top + d2, left:left + d3] = t

-       return (new_image, mask)
+       return (new_image, mask.unsqueeze(0))


NODE_CLASS_MAPPINGS = {
@@ -2258,9 +2261,22 @@ def init_builtin_extra_nodes():
        "nodes_optimalsteps.py",
        "nodes_hidream.py",
        "nodes_fresca.py",
        "nodes_apg.py",
        "nodes_preview_any.py",
        "nodes_ace.py",
        "nodes_string.py",
        "nodes_camera_trajectory.py",
    ]

    import_failed = []
    for node_file in extras_files:
        if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
            import_failed.append(node_file)

    return import_failed


def init_builtin_api_nodes():
    api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
    api_nodes_files = [
        "nodes_ideogram.py",
@@ -2276,11 +2292,10 @@
        "nodes_pika.py",
    ]

-   import_failed = []
-   for node_file in extras_files:
-       if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
-           import_failed.append(node_file)
+   if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
+       return api_nodes_files

    import_failed = []
    for node_file in api_nodes_files:
        if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
            import_failed.append(node_file)
@@ -2288,14 +2303,29 @@
    return import_failed


-def init_extra_nodes(init_custom_nodes=True):
+def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
    import_failed = init_builtin_extra_nodes()

    import_failed_api = []
    if init_api_nodes:
        import_failed_api = init_builtin_api_nodes()

    if init_custom_nodes:
        init_external_custom_nodes()
    else:
        logging.info("Skipping loading of custom nodes")

    if len(import_failed_api) > 0:
        logging.warning("WARNING: some comfy_api_nodes/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
        for node in import_failed_api:
            logging.warning("IMPORT FAILED: {}".format(node))
        logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
        if args.windows_standalone_build:
            logging.warning("Please run the update script: update/update_comfyui.bat")
        else:
            logging.warning("Please do a: pip install -r requirements.txt")
        logging.warning("")

    if len(import_failed) > 0:
        logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
        for node in import_failed:
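The canary.py check above is a cheap gate: if that module (which presumably imports the API nodes' shared dependencies) fails to load, all API node files are reported as failed at once instead of producing one import error per file. A minimal sketch of the same pattern, with hypothetical module names:

```python
import importlib

def load_optional_group(canary_module: str, members: list[str]) -> list[str]:
    """Return the members that failed to load; gate on a canary first."""
    try:
        importlib.import_module(canary_module)  # shared deps live here
    except ImportError:
        return list(members)  # one failure report for the whole group

    failed = []
    for name in members:
        try:
            importlib.import_module(name)
        except ImportError:
            failed.append(name)
    return failed
```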
904
openapi.yaml
Normal file
@@ -0,0 +1,904 @@
openapi: 3.0.3
info:
  title: ComfyUI API
  description: |
    API for ComfyUI - A powerful and modular UI for Stable Diffusion.

    This API allows you to interact with ComfyUI programmatically, including:
    - Submitting workflows for execution
    - Managing the execution queue
    - Retrieving generated images
    - Managing models
    - Retrieving node information
  version: 1.0.0
  license:
    name: GNU General Public License v3.0
    url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE

servers:
  - url: /
    description: Default ComfyUI server

tags:
  - name: workflow
    description: Workflow execution and management
  - name: queue
    description: Queue management
  - name: image
    description: Image handling
  - name: node
    description: Node information
  - name: model
    description: Model management
  - name: system
    description: System information
  - name: internal
    description: Internal API routes

paths:
  /prompt:
    get:
      tags:
        - workflow
      summary: Get information about current prompt execution
      description: Returns information about the current prompt in the execution queue
      operationId: getPromptInfo
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/PromptInfo'
    post:
      tags:
        - workflow
      summary: Submit a workflow for execution
      description: |
        Submit a workflow to be executed by the backend.
        The workflow is a JSON object describing the nodes and their connections.
      operationId: executePrompt
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/PromptRequest'
      responses:
        '200':
          description: Success - Prompt accepted
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/PromptResponse'
        '400':
          description: Invalid prompt
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/ErrorResponse'
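A minimal client for the POST /prompt endpoint, assuming a local default server (the workflow dict is a placeholder; a real graph comes from the UI's "Save (API Format)" export):

```python
import json
from urllib import request

workflow = {}  # placeholder: export a real graph via "Save (API Format)"
payload = {"prompt": workflow, "client_id": "my-client"}

req = request.Request(
    "http://127.0.0.1:8188/prompt",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.load(resp))  # {'prompt_id': ..., 'number': ..., 'node_errors': {...}}
```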
  /queue:
    get:
      tags:
        - queue
      summary: Get queue information
      description: Returns information about running and pending items in the queue
      operationId: getQueueInfo
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/QueueInfo'
    post:
      tags:
        - queue
      summary: Manage queue
      description: Clear the queue or delete specific items
      operationId: manageQueue
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                clear:
                  type: boolean
                  description: If true, clears the entire queue
                delete:
                  type: array
                  description: Array of prompt IDs to delete from the queue
                  items:
                    type: string
                    format: uuid
      responses:
        '200':
          description: Success

  /interrupt:
    post:
      tags:
        - workflow
      summary: Interrupt the current execution
      description: Interrupts the currently running workflow execution
      operationId: interruptExecution
      responses:
        '200':
          description: Success

  /free:
    post:
      tags:
        - system
      summary: Free resources
      description: Unload models and/or free memory
      operationId: freeResources
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                unload_models:
                  type: boolean
                  description: If true, unloads models from memory
                free_memory:
                  type: boolean
                  description: If true, frees GPU memory
      responses:
        '200':
          description: Success

  /history:
    get:
      tags:
        - workflow
      summary: Get execution history
      description: Returns the history of executed workflows
      operationId: getHistory
      parameters:
        - name: max_items
          in: query
          description: Maximum number of history items to return
          required: false
          schema:
            type: integer
            format: int32
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  $ref: '#/components/schemas/HistoryItem'
    post:
      tags:
        - workflow
      summary: Manage history
      description: Clear history or delete specific items
      operationId: manageHistory
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                clear:
                  type: boolean
                  description: If true, clears the entire history
                delete:
                  type: array
                  description: Array of prompt IDs to delete from history
                  items:
                    type: string
                    format: uuid
      responses:
        '200':
          description: Success

  /history/{prompt_id}:
    get:
      tags:
        - workflow
      summary: Get specific history item
      description: Returns a specific history item by ID
      operationId: getHistoryItem
      parameters:
        - name: prompt_id
          in: path
          description: ID of the prompt to retrieve
          required: true
          schema:
            type: string
            format: uuid
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/HistoryItem'

  /object_info:
    get:
      tags:
        - node
      summary: Get all node information
      description: Returns information about all available nodes
      operationId: getNodeInfo
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                additionalProperties:
                  $ref: '#/components/schemas/NodeInfo'

  /object_info/{node_class}:
    get:
      tags:
        - node
      summary: Get specific node information
      description: Returns information about a specific node class
      operationId: getNodeClassInfo
      parameters:
        - name: node_class
          in: path
          description: Name of the node class
          required: true
          schema:
            type: string
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                additionalProperties:
                  $ref: '#/components/schemas/NodeInfo'

  /upload/image:
    post:
      tags:
        - image
      summary: Upload an image
      description: Uploads an image to the server
      operationId: uploadImage
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              type: object
              properties:
                image:
                  type: string
                  format: binary
                  description: The image file to upload
                overwrite:
                  type: string
                  description: Whether to overwrite if file exists (true/false)
                type:
                  type: string
                  enum: [input, temp, output]
                  description: Type of directory to store the image in
                subfolder:
                  type: string
                  description: Subfolder to store the image in
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                properties:
                  name:
                    type: string
                    description: Filename of the uploaded image
                  subfolder:
                    type: string
                    description: Subfolder the image was stored in
                  type:
                    type: string
                    description: Type of directory the image was stored in
        '400':
          description: Bad request
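Uploading is a standard multipart form POST; with the requests package this is a few lines (a sketch assuming a local server and an existing test.png):

```python
import requests

with open("test.png", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8188/upload/image",
        files={"image": ("test.png", f, "image/png")},
        data={"type": "input", "overwrite": "true"},
    )
print(resp.json())  # {'name': 'test.png', 'subfolder': '', 'type': 'input'}
```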
  /upload/mask:
    post:
      tags:
        - image
      summary: Upload a mask for an image
      description: Uploads a mask image and applies it to a referenced original image
      operationId: uploadMask
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              type: object
              properties:
                image:
                  type: string
                  format: binary
                  description: The mask image file to upload
                original_ref:
                  type: string
                  description: JSON string containing reference to the original image
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                properties:
                  name:
                    type: string
                    description: Filename of the uploaded mask
                  subfolder:
                    type: string
                    description: Subfolder the mask was stored in
                  type:
                    type: string
                    description: Type of directory the mask was stored in
        '400':
          description: Bad request

  /view:
    get:
      tags:
        - image
      summary: View an image
      description: Retrieves an image from the server
      operationId: viewImage
      parameters:
        - name: filename
          in: query
          description: Name of the file to retrieve
          required: true
          schema:
            type: string
        - name: type
          in: query
          description: Type of directory to retrieve from
          required: false
          schema:
            type: string
            enum: [input, temp, output]
            default: output
        - name: subfolder
          in: query
          description: Subfolder to retrieve from
          required: false
          schema:
            type: string
        - name: preview
          in: query
          description: Preview options (format;quality)
          required: false
          schema:
            type: string
        - name: channel
          in: query
          description: Channel to retrieve (rgb, a, rgba)
          required: false
          schema:
            type: string
            enum: [rgb, a, rgba]
            default: rgba
      responses:
        '200':
          description: Success
          content:
            image/*:
              schema:
                type: string
                format: binary
        '400':
          description: Bad request
        '404':
          description: File not found
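Fetching a result is a plain GET with query parameters; the filename/subfolder/type triple comes straight from a history entry or an upload response:

```python
import requests

params = {"filename": "ComfyUI_00001_.png", "type": "output", "subfolder": ""}
resp = requests.get("http://127.0.0.1:8188/view", params=params)
with open("result.png", "wb") as f:
    f.write(resp.content)
```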
  /view_metadata/{folder_name}:
    get:
      tags:
        - model
      summary: View model metadata
      description: Retrieves metadata from a safetensors file
      operationId: viewModelMetadata
      parameters:
        - name: folder_name
          in: path
          description: Name of the model folder
          required: true
          schema:
            type: string
        - name: filename
          in: query
          description: Name of the safetensors file
          required: true
          schema:
            type: string
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
        '404':
          description: File not found

  /models:
    get:
      tags:
        - model
      summary: Get model types
      description: Returns a list of available model types
      operationId: getModelTypes
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string

  /models/{folder}:
    get:
      tags:
        - model
      summary: Get models of a specific type
      description: Returns a list of available models of a specific type
      operationId: getModels
      parameters:
        - name: folder
          in: path
          description: Model type folder
          required: true
          schema:
            type: string
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string
        '404':
          description: Folder not found

  /embeddings:
    get:
      tags:
        - model
      summary: Get embeddings
      description: Returns a list of available embeddings
      operationId: getEmbeddings
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string

  /extensions:
    get:
      tags:
        - system
      summary: Get extensions
      description: Returns a list of available extensions
      operationId: getExtensions
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string

  /system_stats:
    get:
      tags:
        - system
      summary: Get system statistics
      description: Returns system information including RAM, VRAM, and ComfyUI version
      operationId: getSystemStats
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/SystemStats'

  /ws:
    get:
      tags:
        - workflow
      summary: WebSocket connection
      description: |
        Establishes a WebSocket connection for real-time communication.
        This endpoint is used for receiving progress updates, status changes, and results from workflow executions.
      operationId: webSocketConnect
      parameters:
        - name: clientId
          in: query
          description: Optional client ID for reconnection
          required: false
          schema:
            type: string
      responses:
        '101':
          description: Switching Protocols to WebSocket
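Listening for execution events works with any WebSocket client. A sketch using the third-party websocket-client package; the "executing" event with a null node is the end-of-prompt convention the repo's example scripts rely on:

```python
import json
import uuid
import websocket  # third-party: pip install websocket-client

client_id = str(uuid.uuid4())
ws = websocket.WebSocket()
ws.connect(f"ws://127.0.0.1:8188/ws?clientId={client_id}")

while True:
    msg = ws.recv()
    if isinstance(msg, str):  # JSON events; binary frames carry previews
        event = json.loads(msg)
        if event["type"] == "executing" and event["data"]["node"] is None:
            break  # the queue entry finished executing
```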
  /internal/logs:
    get:
      tags:
        - internal
      summary: Get logs
      description: Returns system logs as a single string
      operationId: getLogs
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: string

  /internal/logs/raw:
    get:
      tags:
        - internal
      summary: Get raw logs
      description: Returns raw system logs with terminal size information
      operationId: getRawLogs
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                properties:
                  entries:
                    type: array
                    items:
                      type: object
                      properties:
                        t:
                          type: string
                          description: Timestamp
                        m:
                          type: string
                          description: Message
                  size:
                    type: object
                    properties:
                      cols:
                        type: integer
                        description: Terminal columns
                      rows:
                        type: integer
                        description: Terminal rows

  /internal/logs/subscribe:
    patch:
      tags:
        - internal
      summary: Subscribe to logs
      description: Subscribe or unsubscribe to log updates
      operationId: subscribeToLogs
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                clientId:
                  type: string
                  description: Client ID
                enabled:
                  type: boolean
                  description: Whether to enable or disable subscription
      responses:
        '200':
          description: Success

  /internal/folder_paths:
    get:
      tags:
        - internal
      summary: Get folder paths
      description: Returns a map of folder names to their paths
      operationId: getFolderPaths
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: object
                additionalProperties:
                  type: string

  /internal/files/{directory_type}:
    get:
      tags:
        - internal
      summary: Get files
      description: Returns a list of files in a specific directory type
      operationId: getFiles
      parameters:
        - name: directory_type
          in: path
          description: Type of directory (output, input, temp)
          required: true
          schema:
            type: string
            enum: [output, input, temp]
      responses:
        '200':
          description: Success
          content:
            application/json:
              schema:
                type: array
                items:
                  type: string
        '400':
          description: Invalid directory type

components:
  schemas:
    PromptRequest:
      type: object
      required:
        - prompt
      properties:
        prompt:
          type: object
          description: The workflow graph to execute
          additionalProperties: true
        number:
          type: number
          description: Priority number for the queue (lower numbers have higher priority)
        front:
          type: boolean
          description: If true, adds the prompt to the front of the queue
        extra_data:
          type: object
          description: Extra data to be associated with the prompt
          additionalProperties: true
        client_id:
          type: string
          description: Client ID for attribution of the prompt

    PromptResponse:
      type: object
      properties:
        prompt_id:
          type: string
          format: uuid
          description: Unique identifier for the prompt execution
        number:
          type: number
          description: Priority number in the queue
        node_errors:
          type: object
          description: Any errors in the nodes of the prompt
          additionalProperties: true

    ErrorResponse:
      type: object
      properties:
        error:
          type: object
          properties:
            type:
              type: string
              description: Error type
            message:
              type: string
              description: Error message
            details:
              type: string
              description: Detailed error information
            extra_info:
              type: object
              description: Additional error information
              additionalProperties: true
        node_errors:
          type: object
          description: Node-specific errors
          additionalProperties: true

    PromptInfo:
      type: object
      properties:
        exec_info:
          type: object
          properties:
            queue_remaining:
              type: integer
              description: Number of items remaining in the queue

    QueueInfo:
      type: object
      properties:
        queue_running:
          type: array
          items:
            type: object
            description: Currently running items
            additionalProperties: true
        queue_pending:
          type: array
          items:
            type: object
            description: Pending items in the queue
            additionalProperties: true

    HistoryItem:
      type: object
      properties:
        prompt_id:
          type: string
          format: uuid
          description: Unique identifier for the prompt
        prompt:
          type: object
          description: The workflow graph that was executed
          additionalProperties: true
        extra_data:
          type: object
          description: Additional data associated with the execution
          additionalProperties: true
        outputs:
          type: object
          description: Output data from the execution
          additionalProperties: true

    NodeInfo:
      type: object
      properties:
        input:
          type: object
          description: Input specifications for the node
          additionalProperties: true
        input_order:
          type: object
          description: Order of inputs for display
          additionalProperties:
            type: array
            items:
              type: string
        output:
          type: array
          items:
            type: string
          description: Output types of the node
        output_is_list:
          type: array
          items:
            type: boolean
          description: Whether each output is a list
        output_name:
          type: array
          items:
            type: string
          description: Names of the outputs
        name:
          type: string
          description: Internal name of the node
        display_name:
          type: string
          description: Display name of the node
        description:
          type: string
          description: Description of the node
        python_module:
          type: string
          description: Python module implementing the node
        category:
          type: string
          description: Category of the node
        output_node:
          type: boolean
          description: Whether this is an output node
        output_tooltips:
          type: array
          items:
            type: string
          description: Tooltips for outputs
        deprecated:
          type: boolean
          description: Whether the node is deprecated
        experimental:
          type: boolean
          description: Whether the node is experimental
        api_node:
          type: boolean
          description: Whether this is an API node

    SystemStats:
      type: object
      properties:
        system:
          type: object
          properties:
            os:
              type: string
              description: Operating system
            ram_total:
              type: number
              description: Total system RAM in bytes
            ram_free:
              type: number
              description: Free system RAM in bytes
            comfyui_version:
              type: string
              description: ComfyUI version
            python_version:
              type: string
              description: Python version
            pytorch_version:
              type: string
              description: PyTorch version
            embedded_python:
              type: boolean
              description: Whether using embedded Python
            argv:
              type: array
              items:
                type: string
              description: Command line arguments
        devices:
          type: array
          items:
            type: object
            properties:
              name:
                type: string
                description: Device name
              type:
                type: string
                description: Device type
              index:
                type: integer
                description: Device index
              vram_total:
                type: number
                description: Total VRAM in bytes
              vram_free:
                type: number
                description: Free VRAM in bytes
              torch_vram_total:
                type: number
                description: Total VRAM as reported by PyTorch
              torch_vram_free:
                type: number
                description: Free VRAM as reported by PyTorch
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.31"
+version = "0.3.34"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.18.9
-comfyui-workflow-templates==0.1.11
+comfyui-frontend-package==1.19.9
+comfyui-workflow-templates==0.1.14
torch
torchsde
torchvision

@@ -101,6 +101,14 @@ prompt_text = """

def queue_prompt(prompt):
    p = {"prompt": prompt}

    # If the workflow contains API nodes, you can add a Comfy API key to the `extra_data` field of the payload.
    # p["extra_data"] = {
    #     "api_key_comfy_org": "comfyui-87d01e28d*******************************************************"  # replace with real key
    # }
    # See: https://docs.comfy.org/tutorials/api-nodes/overview
    # Generate a key here: https://platform.comfy.org/login

    data = json.dumps(p).encode('utf-8')
    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
    request.urlopen(req)

15
server.py
@@ -32,12 +32,13 @@ from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
-from typing import Optional
+from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes

class BinaryEventTypes:
    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2
+   TEXT = 3

async def send_socket_catch_exception(function, message):
    try:
@@ -878,3 +879,15 @@ class PromptServer():
            logging.warning(traceback.format_exc())

        return json_data

    def send_progress_text(
        self, text: Union[bytes, bytearray, str], node_id: str, sid=None
    ):
        if isinstance(text, str):
            text = text.encode("utf-8")
        node_id_bytes = str(node_id).encode("utf-8")

        # Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
        message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text

        self.send_sync(BinaryEventTypes.TEXT, message, sid)
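On the wire this arrives as a binary WebSocket frame: an event-type header (assumed here to be the big-endian u32 that send_sync prepends to binary events), then the 4-byte node_id length, the node_id, and the text. A hedged sketch of client-side decoding under that layout:

```python
import struct

def decode_text_event(frame: bytes):
    """Decode a BinaryEventTypes.TEXT frame: [event u32][len u32][node_id][text]."""
    event_type, node_id_len = struct.unpack(">II", frame[:8])
    node_id = frame[8:8 + node_id_len].decode("utf-8")
    text = frame[8 + node_id_len:].decode("utf-8")
    return event_type, node_id, text

# Round-trip against the packing in send_progress_text:
node_id, text = b"12", b"step 3/20"
frame = struct.pack(">I", 3) + struct.pack(">I", len(node_id)) + node_id + text
print(decode_text_event(frame))  # (3, '12', 'step 3/20')
```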
74
tests-api/README.md
Normal file
@@ -0,0 +1,74 @@
# ComfyUI API Testing

This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI.

## Setup

1. Install the required dependencies:

```bash
pip install -r requirements.txt
```

2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188)

## Running the Tests

Run all tests with pytest:

```bash
cd tests-api
pytest
```

Run specific test files:

```bash
pytest test_spec_validation.py
pytest test_endpoint_existence.py
pytest test_schema_validation.py
pytest test_api_by_tag.py
```

Run tests with more verbose output:

```bash
pytest -v
```

## Test Categories

The tests are organized into several categories:

1. **Spec Validation**: Validates that the OpenAPI specification is valid.
2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server.
3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec.
4. **Tag-Based Tests**: Tests that the API's tag organization is consistent.

## Using a Different Server

By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable:

```bash
COMFYUI_SERVER_URL=http://example.com:8188 pytest
```

## Test Structure

- `conftest.py`: Contains pytest fixtures used by the tests.
- `utils/`: Contains utility functions for working with the OpenAPI spec.
- `test_*.py`: The actual test files.
- `resources/`: Contains resources used by the tests (e.g., sample workflows).

## Extending the Tests

To add new tests:

1. For testing new endpoints, add them to the appropriate test file based on their category.
2. For testing more complex functionality, create a new test file following the established patterns.

## Notes

- Tests that require a running server will be skipped if the server is not available.
- Some tests may fail if the server doesn't match the specification exactly.
- The tests don't modify any data on the server (they're read-only).
141
tests-api/conftest.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Test fixtures for API testing
|
||||
"""
|
||||
import os
import pytest
import yaml
import requests
import logging
from typing import Dict, Any, Generator, Optional
from urllib.parse import urljoin

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default server configuration
DEFAULT_SERVER_URL = "http://127.0.0.1:8188"


@pytest.fixture(scope="session")
def api_spec_path() -> str:
    """
    Get the path to the OpenAPI specification file

    Returns:
        Path to the OpenAPI specification file
    """
    return os.path.abspath(os.path.join(
        os.path.dirname(__file__),
        "..",
        "openapi.yaml"
    ))


@pytest.fixture(scope="session")
def api_spec(api_spec_path: str) -> Dict[str, Any]:
    """
    Load the OpenAPI specification

    Args:
        api_spec_path: Path to the spec file

    Returns:
        Parsed OpenAPI specification
    """
    with open(api_spec_path, 'r') as f:
        return yaml.safe_load(f)


@pytest.fixture(scope="session")
def base_url() -> str:
    """
    Get the base URL for the API server

    Returns:
        Base URL string
    """
    # Allow overriding via environment variable
    return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL)


@pytest.fixture(scope="session")
def server_available(base_url: str) -> bool:
    """
    Check if the server is available

    Args:
        base_url: Base URL for the API

    Returns:
        True if the server is available, False otherwise
    """
    try:
        response = requests.get(base_url, timeout=2)
        return response.status_code == 200
    except requests.RequestException:
        logger.warning(f"Server at {base_url} is not available")
        return False


@pytest.fixture
def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
    """
    Create a requests session for API testing

    Args:
        base_url: Base URL for the API

    Yields:
        Requests session configured for the API
    """
    session = requests.Session()

    # Helper function to construct URLs
    def get_url(path: str) -> str:
        return urljoin(base_url, path)

    # Add url helper to the session
    session.get_url = get_url  # type: ignore

    yield session

    # Cleanup
    session.close()


@pytest.fixture
def api_get_json(api_client: requests.Session):
    """
    Helper fixture for making GET requests and parsing JSON responses

    Args:
        api_client: API client session

    Returns:
        Function that makes GET requests and returns JSON
    """
    def _get_json(path: str, **kwargs):
        url = api_client.get_url(path)  # type: ignore
        response = api_client.get(url, **kwargs)

        if response.status_code == 200:
            try:
                return response.json()
            except ValueError:
                return None
        return None

    return _get_json


@pytest.fixture
def require_server(server_available):
    """
    Skip tests if server is not available

    Args:
        server_available: Whether the server is available
    """
    if not server_available:
        pytest.skip("Server is not available")
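These fixtures resolve the server URL once per session, so the whole suite can be pointed at a different ComfyUI instance without touching the code. A minimal sketch of driving the suite programmatically (the COMFYUI_SERVER_URL variable comes from the base_url fixture above; the host address here is invented for illustration):

import os
import pytest

# Point the fixtures at another ComfyUI instance (address is a placeholder)
os.environ["COMFYUI_SERVER_URL"] = "http://192.168.1.50:8188"

# One way to run the suite; equivalent to `pytest tests-api -v` from the repo root
raise SystemExit(pytest.main(["tests-api", "-v"]))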
6
tests-api/requirements.txt
Normal file
@@ -0,0 +1,6 @@
pytest>=7.0.0
pytest-asyncio>=0.21.0
openapi-spec-validator>=0.5.0
jsonschema>=4.17.0
requests>=2.28.0
pyyaml>=6.0.0
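These pins are what the validation workflow installs before running the spec tests. Locally the same setup can be scripted; a minimal sketch using the standard pip-via-subprocess idiom:

import subprocess
import sys

# Same effect as the CI step: pip install -r tests-api/requirements.txt
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-r", "tests-api/requirements.txt"]
)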
279
tests-api/test_api_by_tag.py
Normal file
@@ -0,0 +1,279 @@
"""
Tests for API endpoints grouped by tags
"""
import pytest
import logging
import sys
import os
from typing import Dict, Any, Set

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define functions inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

def get_all_tags(spec):
    """
    Get all tags used in the API spec
    """
    tags = set()

    for path_item in spec['paths'].values():
        for operation in path_item.values():
            if isinstance(operation, dict) and 'tags' in operation:
                tags.update(operation['tags'])

    return tags

def extract_endpoints_by_tag(spec, tag):
    """
    Extract all endpoints with a specific tag
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            if tag in operation.get('tags', []):
                endpoints.append({
                    'path': path,
                    'method': method.lower(),
                    'operation_id': operation.get('operationId', ''),
                    'summary': operation.get('summary', '')
                })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
    """
    Get all tags from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        Set of tag names
    """
    return get_all_tags(api_spec)


def test_api_has_tags(api_tags: Set[str]):
    """
    Test that the API has defined tags

    Args:
        api_tags: Set of tags
    """
    assert len(api_tags) > 0, "API spec should have at least one tag"

    # Log the tags
    logger.info(f"API spec has the following tags: {sorted(api_tags)}")


@pytest.mark.parametrize("tag", [
    "workflow",
    "image",
    "model",
    "node",
    "system"
])
def test_core_tags_exist(api_tags: Set[str], tag: str):
    """
    Test that core tags exist in the API spec

    Args:
        api_tags: Set of tags
        tag: Tag to check
    """
    assert tag in api_tags, f"API spec should have '{tag}' tag"


def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'workflow' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "workflow")

    assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"

    # Check for key workflow endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'image' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "image")

    assert len(endpoints) > 0, "No endpoints found with 'image' tag"

    # Check for key image endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
    assert "/view" in endpoint_paths, "Image tag should include /view endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'model' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "model")

    assert len(endpoints) > 0, "No endpoints found with 'model' tag"

    # Check for key model endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/models" in endpoint_paths, "Model tag should include /models endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'node' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "node")

    assert len(endpoints) > 0, "No endpoints found with 'node' tag"

    # Check for key node endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'system' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "system")

    assert len(endpoints) > 0, "No endpoints found with 'system' tag"

    # Check for key system endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
    """
    Test that the 'internal' tag has appropriate endpoints

    Args:
        api_spec: Loaded OpenAPI spec
    """
    endpoints = extract_endpoints_by_tag(api_spec, "internal")

    assert len(endpoints) > 0, "No endpoints found with 'internal' tag"

    # Check for key internal endpoints
    endpoint_paths = [e["path"] for e in endpoints]
    assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"

    # Log the endpoints
    logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
    for e in endpoints:
        logger.info(f"  {e['method'].upper()} {e['path']}")


def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
    """
    Test that operation IDs follow a consistent pattern with their tag

    Args:
        api_spec: Loaded OpenAPI spec
    """
    failures = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation and 'tags' in operation and operation['tags']:
                    op_id = operation['operationId']
                    primary_tag = operation['tags'][0].lower()

                    # Check if operationId starts with primary tag prefix
                    # This is a common convention, but might need adjusting
                    if not (op_id.startswith(primary_tag) or
                            any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
                        failures.append({
                            'path': path,
                            'method': method,
                            'operationId': op_id,
                            'primary_tag': primary_tag
                        })

    # Log failures for diagnosis but don't fail the test
    # as this is a style/convention check
    if failures:
        logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
        for f in failures:
            logger.warning(f"  {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")
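To make the behavior of the inline helpers concrete, here is a minimal sketch against a hand-written two-endpoint spec (the dict below is invented for illustration and is not part of openapi.yaml):

toy_spec = {
    "paths": {
        "/prompt": {
            "post": {"tags": ["workflow"], "operationId": "submitPrompt", "summary": "Queue a prompt"},
            "get": {"tags": ["workflow"], "operationId": "getPromptInfo", "summary": "Prompt info"},
        },
        "/view": {
            "get": {"tags": ["image"], "operationId": "viewImage", "summary": "View an image"},
        },
    }
}

assert get_all_tags(toy_spec) == {"workflow", "image"}
assert len(extract_endpoints_by_tag(toy_spec, "workflow")) == 2
assert extract_endpoints_by_tag(toy_spec, "image")[0]["operation_id"] == "viewImage"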
240
tests-api/test_endpoint_existence.py
Normal file
@@ -0,0 +1,240 @@
"""
Tests for endpoint existence and basic response codes
"""
import pytest
import requests
import logging
import sys
import os
from typing import Dict, Any, List
from urllib.parse import urljoin

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define get_all_endpoints function inline to avoid import issues
def get_all_endpoints(spec):
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Get all endpoints from the API spec

    Args:
        api_spec: Loaded OpenAPI spec

    Returns:
        List of endpoint information
    """
    return get_all_endpoints(api_spec)


def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
    """
    Test that endpoints are defined in the spec

    Args:
        all_endpoints: List of endpoint information
    """
    # Simple check that we have endpoints defined
    assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"

    # Log the endpoints for informational purposes
    logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
    for endpoint in all_endpoints:
        logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}")


@pytest.mark.parametrize("endpoint_path", [
    "/",              # Root path
    "/prompt",        # Get prompt info
    "/queue",         # Get queue
    "/models",        # Get model types
    "/object_info",   # Get node info
    "/system_stats"   # Get system stats
])
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
    """
    Test that basic GET endpoints exist and respond

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        endpoint_path: Path to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    try:
        response = api_client.get(url)

        # We're just checking that the endpoint exists and returns some kind of response
        # Not necessarily a 200 status code
        assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"

        logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_websocket_endpoint_exists(require_server, base_url: str):
    """
    Test that the WebSocket endpoint exists

    Args:
        require_server: Fixture that skips if server is not available
        base_url: Base server URL
    """
    ws_url = urljoin(base_url, "/ws")

    # For WebSocket, we can't use a normal GET request
    # Instead, we make a HEAD request to check if the endpoint exists
    try:
        response = requests.head(ws_url)

        # WebSocket endpoints often return a 400 Bad Request for HEAD requests
        # but a 404 would indicate the endpoint doesn't exist
        assert response.status_code != 404, "WebSocket endpoint /ws does not exist"

        logger.info(f"WebSocket endpoint exists with status code {response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")


def test_api_models_folder_endpoint(require_server, api_client):
    """
    Test that the /models/{folder} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available model types
    models_url = api_client.get_url("/models")  # type: ignore

    try:
        models_response = api_client.get(models_url)
        assert models_response.status_code == 200, "Failed to get model types"

        model_types = models_response.json()

        # Skip if no model types available
        if not model_types:
            pytest.skip("No model types available to test")

        # Test with the first model type
        model_type = model_types[0]
        models_folder_url = api_client.get_url(f"/models/{model_type}")  # type: ignore

        folder_response = api_client.get(models_folder_url)

        # We're just checking that the endpoint exists
        assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"

        logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, IndexError) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_api_object_info_node_endpoint(require_server, api_client):
    """
    Test that the /object_info/{node_class} endpoint exists and responds

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    # First get available node classes
    objects_url = api_client.get_url("/object_info")  # type: ignore

    try:
        objects_response = api_client.get(objects_url)
        assert objects_response.status_code == 200, "Failed to get object info"

        node_classes = objects_response.json()

        # Skip if no node classes available
        if not node_classes:
            pytest.skip("No node classes available to test")

        # Test with the first node class
        node_class = next(iter(node_classes.keys()))
        node_url = api_client.get_url(f"/object_info/{node_class}")  # type: ignore

        node_response = api_client.get(node_url)

        # We're just checking that the endpoint exists
        assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"

        logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")

    except requests.RequestException as e:
        pytest.fail(f"Request failed: {str(e)}")
    except (ValueError, KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_internal_endpoints_exist(require_server, api_client):
    """
    Test that internal endpoints exist

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
    """
    internal_endpoints = [
        "/internal/logs",
        "/internal/logs/raw",
        "/internal/folder_paths",
        "/internal/files/output"
    ]

    for endpoint in internal_endpoints:
        url = api_client.get_url(endpoint)  # type: ignore

        try:
            response = api_client.get(url)

            # We're just checking that the endpoint exists
            assert response.status_code != 404, f"Endpoint {endpoint} does not exist"

            logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")

        except requests.RequestException as e:
            logger.warning(f"Request to {endpoint} failed: {str(e)}")
            # Don't fail the test as internal endpoints might be restricted
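One stdlib detail these tests lean on: urljoin with an absolute path keeps only the scheme and host from the base URL, which is exactly what the get_url helper wants. A quick sketch:

from urllib.parse import urljoin

base = "http://127.0.0.1:8188"
assert urljoin(base, "/ws") == "http://127.0.0.1:8188/ws"
assert urljoin(base, "/models/checkpoints") == "http://127.0.0.1:8188/models/checkpoints"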
440
tests-api/test_schema_validation.py
Normal file
@@ -0,0 +1,440 @@
"""
Tests for validating API responses against OpenAPI schema
"""
import pytest
import requests
import logging
import sys
import os
import json
from typing import Dict, Any

# Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Define validation functions inline to avoid import issues
def get_endpoint_schema(
    spec,
    path,
    method,
    status_code='200'
):
    """
    Extract response schema for a specific endpoint from OpenAPI spec
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']

def resolve_schema_refs(schema, spec):
    """
    Resolve $ref references in a schema
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result

def validate_response(
    response_data,
    spec,
    path,
    method,
    status_code='200'
):
    """
    Validate a response against the OpenAPI schema
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        import jsonschema
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        # Extract more detailed error information
        path = ".".join(str(p) for p in e.path) if e.path else "root"
        instance = e.instance if not isinstance(e.instance, dict) else "..."
        schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"

        detailed_error = (
            f"Validation error at path: {path}\n"
            f"Schema path: {schema_path}\n"
            f"Error message: {e.message}\n"
            f"Failed instance: {instance}\n"
        )

        return {'valid': False, 'errors': [detailed_error]}

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.mark.parametrize("endpoint_path,method", [
    ("/system_stats", "get"),
    ("/prompt", "get"),
    ("/queue", "get"),
    ("/models", "get"),
    ("/embeddings", "get")
])
def test_response_schema_validation(
    require_server,
    api_client,
    api_spec: Dict[str, Any],
    endpoint_path: str,
    method: str
):
    """
    Test that API responses match the defined schema

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
        endpoint_path: Path to test
        method: HTTP method to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    # Skip if no schema defined
    schema = get_endpoint_schema(api_spec, endpoint_path, method)
    if not schema:
        pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")

    try:
        if method.lower() == "get":
            response = api_client.get(url)
        else:
            pytest.skip(f"Method {method} not implemented for automated testing")
            return

        # Skip if response is not 200
        if response.status_code != 200:
            pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
            return

        # Skip if response is not JSON
        try:
            response_data = response.json()
        except ValueError:
            pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
            return

        # Validate the response
        validation_result = validate_response(
            response_data,
            api_spec,
            endpoint_path,
            method
        )

        if validation_result['valid']:
            logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
        else:
            for error in validation_result['errors']:
                logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")

        assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")


def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the system_stats endpoint response in detail

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/system_stats")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get system stats"

        # Parse response
        stats = response.json()

        # Validate high-level structure
        assert 'system' in stats, "Response missing 'system' field"
        assert 'devices' in stats, "Response missing 'devices' field"

        # Validate system fields
        system = stats['system']
        assert 'os' in system, "System missing 'os' field"
        assert 'ram_total' in system, "System missing 'ram_total' field"
        assert 'ram_free' in system, "System missing 'ram_free' field"
        assert 'comfyui_version' in system, "System missing 'comfyui_version' field"

        # Validate devices fields
        devices = stats['devices']
        assert isinstance(devices, list), "Devices should be a list"

        if devices:
            device = devices[0]
            assert 'name' in device, "Device missing 'name' field"
            assert 'type' in device, "Device missing 'type' field"
            assert 'vram_total' in device, "Device missing 'vram_total' field"
            assert 'vram_free' in device, "Device missing 'vram_free' field"

        # Perform schema validation
        validation_result = validate_response(
            stats,
            api_spec,
            "/system_stats",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /system_stats: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/system_stats", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print sample of the response
            logger.error(f"Response:\n{json.dumps(stats, indent=2)}")

        assert validation_result['valid'], "System stats response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /system_stats failed: {str(e)}")


def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the models endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/models")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get models"

        # Parse response
        models = response.json()

        # Validate it's a list
        assert isinstance(models, list), "Models response should be a list"

        # Each item should be a string
        for model in models:
            assert isinstance(model, str), "Each model type should be a string"

        # Perform schema validation
        validation_result = validate_response(
            models,
            api_spec,
            "/models",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /models: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/models", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            sample_models = models[:5] if isinstance(models, list) else models
            logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")

        assert validation_result['valid'], "Models response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /models failed: {str(e)}")


def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the object_info endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/object_info")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get object info"

        # Parse response
        objects = response.json()

        # Validate it's an object
        assert isinstance(objects, dict), "Object info response should be an object"

        # Check if we have any objects
        if objects:
            # Get the first object
            first_obj_name = next(iter(objects.keys()))
            first_obj = objects[first_obj_name]

            # Validate first object has required fields
            assert 'input' in first_obj, "Object missing 'input' field"
            assert 'output' in first_obj, "Object missing 'output' field"
            assert 'name' in first_obj, "Object missing 'name' field"

        # Perform schema validation
        validation_result = validate_response(
            objects,
            api_spec,
            "/object_info",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /object_info: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/object_info", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Also print a small sample of the response
            sample = dict(list(objects.items())[:1]) if objects else {}
            logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")

        assert validation_result['valid'], "Object info response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /object_info failed: {str(e)}")
    except (KeyError, StopIteration) as e:
        pytest.fail(f"Failed to process response: {str(e)}")


def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the queue endpoint response

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/queue")  # type: ignore

    try:
        response = api_client.get(url)

        assert response.status_code == 200, "Failed to get queue"

        # Parse response
        queue = response.json()

        # Validate structure
        assert 'queue_running' in queue, "Queue missing 'queue_running' field"
        assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"

        # Each should be a list
        assert isinstance(queue['queue_running'], list), "queue_running should be a list"
        assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"

        # Perform schema validation
        validation_result = validate_response(
            queue,
            api_spec,
            "/queue",
            "get"
        )

        # Print detailed error if validation fails
        if not validation_result['valid']:
            for error in validation_result['errors']:
                logger.error(f"Validation error for /queue: {error}")

            # Print schema details for debugging
            schema = get_endpoint_schema(api_spec, "/queue", "get")
            if schema:
                logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")

            # Print response
            logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")

        assert validation_result['valid'], "Queue response does not match schema"

    except requests.RequestException as e:
        pytest.fail(f"Request to /queue failed: {str(e)}")
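The $ref resolution above can be exercised without a server; a minimal sketch with an invented component schema:

toy_spec = {
    "components": {
        "schemas": {
            "Stats": {"type": "object", "properties": {"os": {"type": "string"}}}
        }
    }
}

schema = {"$ref": "#/components/schemas/Stats"}
resolved = resolve_schema_refs(schema, toy_spec)
assert resolved == {"type": "object", "properties": {"os": {"type": "string"}}}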
144
tests-api/test_spec_validation.py
Normal file
@@ -0,0 +1,144 @@
"""
Tests for validating the OpenAPI specification
"""
import pytest
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from typing import Dict, Any


def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI specification is valid

    Args:
        api_spec: Loaded OpenAPI spec
    """
    try:
        validate_spec(api_spec)
    except OpenAPISpecValidatorError as e:
        pytest.fail(f"OpenAPI spec validation failed: {str(e)}")


def test_spec_has_info(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has the required info section

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'info' in api_spec, "Spec must have info section"
    assert 'title' in api_spec['info'], "Info must have title"
    assert 'version' in api_spec['info'], "Info must have version"


def test_spec_has_paths(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has paths defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'paths' in api_spec, "Spec must have paths section"
    assert len(api_spec['paths']) > 0, "Spec must have at least one path"


def test_spec_has_components(api_spec: Dict[str, Any]):
    """
    Test that the OpenAPI spec has components defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert 'components' in api_spec, "Spec must have components section"
    assert 'schemas' in api_spec['components'], "Components must have schemas"


def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core workflow endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint"
    assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt"
    assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt"


def test_image_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core image endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint"
    assert '/view' in api_spec['paths'], "Spec must define /view endpoint"


def test_model_endpoints_exist(api_spec: Dict[str, Any]):
    """
    Test that core model endpoints are defined

    Args:
        api_spec: Loaded OpenAPI spec
    """
    assert '/models' in api_spec['paths'], "Spec must define /models endpoint"
    assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint"


def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
    """
    Test that all operationIds are unique

    Args:
        api_spec: Loaded OpenAPI spec
    """
    operation_ids = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' in operation:
                    operation_ids.append(operation['operationId'])

    # Check for duplicates
    duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
    assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"


def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have operationIds

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'operationId' not in operation:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"


def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
    """
    Test that all endpoints have tags

    Args:
        api_spec: Loaded OpenAPI spec
    """
    missing = []

    for path, path_item in api_spec['paths'].items():
        for method, operation in path_item.items():
            if method in ['get', 'post', 'put', 'delete', 'patch']:
                if 'tags' not in operation or not operation['tags']:
                    missing.append(f"{method.upper()} {path}")

    assert len(missing) == 0, f"Found endpoints without tags: {missing}"
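For reference, validate_spec works on any parsed dict, so the same check can be run ad hoc; a minimal sketch with roughly the smallest valid OpenAPI 3.0 document:

from openapi_spec_validator import validate_spec

minimal_spec = {
    "openapi": "3.0.0",
    "info": {"title": "Minimal", "version": "1.0.0"},
    "paths": {},
}

validate_spec(minimal_spec)  # raises on an invalid document, returns None otherwise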
157
tests-api/utils/schema_utils.py
Normal file
@@ -0,0 +1,157 @@
"""
Utilities for working with OpenAPI schemas
"""
from typing import Any, Dict, List, Optional, Set, Tuple


def extract_required_parameters(
    spec: Dict[str, Any],
    path: str,
    method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Extract required parameters for a specific endpoint

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')

    Returns:
        Tuple of (path_params, query_params) containing required parameters
    """
    method = method.lower()
    path_params = []
    query_params = []

    # Handle path not found
    if path not in spec['paths']:
        return path_params, query_params

    # Handle method not found
    if method not in spec['paths'][path]:
        return path_params, query_params

    # Get parameters
    params = spec['paths'][path][method].get('parameters', [])

    for param in params:
        if param.get('required', False):
            if param.get('in') == 'path':
                path_params.append(param)
            elif param.get('in') == 'query':
                query_params.append(param)

    return path_params, query_params


def get_request_body_schema(
    spec: Dict[str, Any],
    path: str,
    method: str
) -> Optional[Dict[str, Any]]:
    """
    Get request body schema for a specific endpoint

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')

    Returns:
        Request body schema or None if not found
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle no request body
    request_body = spec['paths'][path][method].get('requestBody', {})
    if not request_body or 'content' not in request_body:
        return None

    # Get schema from first content type
    content_types = request_body['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']


def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
    """
    Extract all endpoints with a specific tag

    Args:
        spec: Parsed OpenAPI specification
        tag: Tag to filter by

    Returns:
        List of endpoint details
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            if tag in operation.get('tags', []):
                endpoints.append({
                    'path': path,
                    'method': method.lower(),
                    'operation_id': operation.get('operationId', ''),
                    'summary': operation.get('summary', '')
                })

    return endpoints


def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
    """
    Get all tags used in the API spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        Set of tag names
    """
    tags = set()

    for path_item in spec['paths'].values():
        for operation in path_item.values():
            if isinstance(operation, dict) and 'tags' in operation:
                tags.update(operation['tags'])

    return tags


def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract all examples from component schemas

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        Dict mapping schema names to examples
    """
    examples = {}

    if 'components' not in spec or 'schemas' not in spec['components']:
        return examples

    for name, schema in spec['components']['schemas'].items():
        if 'example' in schema:
            examples[name] = schema['example']

    return examples
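A minimal sketch of extract_required_parameters on an invented path item, showing how required parameters are split by location:

toy_spec = {
    "paths": {
        "/view": {
            "get": {
                "parameters": [
                    {"name": "filename", "in": "query", "required": True},
                    {"name": "preview", "in": "query", "required": False},
                ]
            }
        }
    }
}

path_params, query_params = extract_required_parameters(toy_spec, "/view", "get")
assert path_params == []
assert [p["name"] for p in query_params] == ["filename"]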
178
tests-api/utils/validation.py
Normal file
@@ -0,0 +1,178 @@
"""
Utilities for API response validation against OpenAPI spec
"""
import yaml
import jsonschema
from typing import Any, Dict, List, Optional, Union


def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
    """
    Load the OpenAPI specification from a YAML file

    Args:
        spec_path: Path to the OpenAPI specification file

    Returns:
        Dict containing the parsed OpenAPI spec
    """
    with open(spec_path, 'r') as f:
        return yaml.safe_load(f)


def get_endpoint_schema(
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Optional[Dict[str, Any]]:
    """
    Extract response schema for a specific endpoint from OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to get schema for

    Returns:
        Schema dict or None if not found
    """
    method = method.lower()

    # Handle path not found
    if path not in spec['paths']:
        return None

    # Handle method not found
    if method not in spec['paths'][path]:
        return None

    # Handle status code not found
    responses = spec['paths'][path][method].get('responses', {})
    if status_code not in responses:
        return None

    # Handle no content defined
    if 'content' not in responses[status_code]:
        return None

    # Get schema from first content type
    content_types = responses[status_code]['content']
    first_content_type = next(iter(content_types))

    if 'schema' not in content_types[first_content_type]:
        return None

    return content_types[first_content_type]['schema']


def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
    """
    Resolve $ref references in a schema

    Args:
        schema: Schema that may contain references
        spec: Full OpenAPI spec with component definitions

    Returns:
        Schema with references resolved
    """
    if not isinstance(schema, dict):
        return schema

    result = {}

    for key, value in schema.items():
        if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
            # Handle reference
            ref_path = value[2:].split('/')
            ref_value = spec
            for path_part in ref_path:
                ref_value = ref_value.get(path_part, {})

            # Recursively resolve any refs in the referenced schema
            ref_value = resolve_schema_refs(ref_value, spec)
            result.update(ref_value)
        elif isinstance(value, dict):
            # Recursively resolve refs in nested dictionaries
            result[key] = resolve_schema_refs(value, spec)
        elif isinstance(value, list):
            # Recursively resolve refs in list items
            result[key] = [
                resolve_schema_refs(item, spec) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            # Pass through other values
            result[key] = value

    return result


def validate_response(
    response_data: Union[Dict[str, Any], List[Any]],
    spec: Dict[str, Any],
    path: str,
    method: str,
    status_code: str = '200'
) -> Dict[str, Any]:
    """
    Validate a response against the OpenAPI schema

    Args:
        response_data: Response data to validate
        spec: Parsed OpenAPI specification
        path: API path (e.g., '/prompt')
        method: HTTP method (e.g., 'get', 'post')
        status_code: HTTP status code to validate against

    Returns:
        Dict with validation result containing:
        - valid: bool indicating if validation passed
        - errors: List of validation errors if any
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    try:
        jsonschema.validate(instance=response_data, schema=resolved_schema)
        return {'valid': True, 'errors': []}
    except jsonschema.exceptions.ValidationError as e:
        return {'valid': False, 'errors': [str(e)]}


def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extract all endpoints from an OpenAPI spec

    Args:
        spec: Parsed OpenAPI specification

    Returns:
        List of dicts with path, method, and tags for each endpoint
    """
    endpoints = []

    for path, path_item in spec['paths'].items():
        for method, operation in path_item.items():
            if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
                continue

            endpoints.append({
                'path': path,
                'method': method.lower(),
                'tags': operation.get('tags', []),
                'operation_id': operation.get('operationId', ''),
                'summary': operation.get('summary', '')
            })

    return endpoints
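Putting the pieces together, validate_response can be driven entirely offline; a minimal sketch with an invented one-endpoint spec:

toy_spec = {
    "paths": {
        "/models": {
            "get": {
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {"type": "array", "items": {"type": "string"}}
                            }
                        }
                    }
                }
            }
        }
    }
}

assert validate_response(["checkpoints", "loras"], toy_spec, "/models", "get")["valid"]
assert not validate_response({"not": "a list"}, toy_spec, "/models", "get")["valid"]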
239
tests-unit/comfy_api_test/video_types_test.py
Normal file
@@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError

EPSILON = 0.0001


@pytest.fixture
def sample_images():
    """3-frame 2x2 RGB video tensor"""
    return torch.rand(3, 2, 2, 3)


@pytest.fixture
def sample_audio():
    """Stereo audio with 44.1kHz sample rate"""
    return AudioInput(
        {
            "waveform": torch.rand(1, 2, 1000),
            "sample_rate": 44100,
        }
    )


@pytest.fixture
def video_components(sample_images, sample_audio):
    """VideoComponents with images, audio, and metadata"""
    return VideoComponents(
        images=sample_images,
        audio=sample_audio,
        frame_rate=Fraction(30),
        metadata={"test": "metadata"},
    )


def create_test_video(width=4, height=4, frames=3, fps=30):
    """Helper to create a temporary video file"""
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    with av.open(tmp.name, mode="w") as container:
        stream = container.add_stream("h264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        for i in range(frames):
            frame = av.VideoFrame.from_ndarray(
                torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
                format="rgb24",
            )
            frame = frame.reformat(format="yuv420p")
            packet = stream.encode(frame)
            container.mux(packet)

        # Flush
        packet = stream.encode(None)
        container.mux(packet)

    return tmp.name


@pytest.fixture
def simple_video_file():
    """4x4 video with 3 frames at 30fps"""
    file_path = create_test_video()
    yield file_path
    os.unlink(file_path)


def test_video_from_components_get_duration(video_components):
    """Duration calculated correctly from frame count and frame rate"""
    video = VideoFromComponents(video_components)
    duration = video.get_duration()

    expected_duration = 3.0 / 30.0
    assert duration == pytest.approx(expected_duration)


def test_video_from_components_get_duration_different_frame_rates(sample_images):
    """Duration correct for different frame rates including fractional"""
    # Test with 60 fps
    components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
    video_60fps = VideoFromComponents(components_60fps)
    assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)

    # Test with fractional frame rate (23.976fps)
    components_frac = VideoComponents(
        images=sample_images, frame_rate=Fraction(24000, 1001)
    )
    video_frac = VideoFromComponents(components_frac)
    expected_frac = 3.0 / (24000.0 / 1001.0)
    assert video_frac.get_duration() == pytest.approx(expected_frac)


def test_video_from_components_get_duration_empty_video():
    """Duration is zero for empty video"""
    empty_components = VideoComponents(
        images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
    )
    video = VideoFromComponents(empty_components)
    assert video.get_duration() == 0.0


def test_video_from_components_get_dimensions(video_components):
    """Dimensions returned correctly from image tensor shape"""
    video = VideoFromComponents(video_components)
    width, height = video.get_dimensions()
    assert width == 2
    assert height == 2


def test_video_from_file_get_duration(simple_video_file):
    """Duration extracted from file metadata"""
    video = VideoFromFile(simple_video_file)
    duration = video.get_duration()
    assert duration == pytest.approx(0.1, abs=0.01)


def test_video_from_file_get_dimensions(simple_video_file):
    """Dimensions read from stream without decoding frames"""
    video = VideoFromFile(simple_video_file)
    width, height = video.get_dimensions()
    assert width == 4
    assert height == 4


def test_video_from_file_bytesio_input():
    """VideoFromFile works with BytesIO input"""
    buffer = io.BytesIO()
    with av.open(buffer, mode="w", format="mp4") as container:
        stream = container.add_stream("h264", rate=30)
        stream.width = 2
        stream.height = 2
        stream.pix_fmt = "yuv420p"

        frame = av.VideoFrame.from_ndarray(
            torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
        )
        frame = frame.reformat(format="yuv420p")
        packet = stream.encode(frame)
        container.mux(packet)
        packet = stream.encode(None)
        container.mux(packet)

    buffer.seek(0)
    video = VideoFromFile(buffer)

    assert video.get_dimensions() == (2, 2)
    assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)


def test_video_from_file_invalid_file_error():
    """InvalidDataError raised for non-video files"""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
        tmp.write(b"not a video file")
        tmp.flush()
        tmp_name = tmp.name

    try:
        with pytest.raises(InvalidDataError):
            video = VideoFromFile(tmp_name)
            video.get_dimensions()
    finally:
        os.unlink(tmp_name)


def test_video_from_file_audio_only_error():
    """ValueError raised for audio-only files"""
    with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
        tmp_name = tmp.name

    try:
        with av.open(tmp_name, mode="w") as container:
            stream = container.add_stream("aac", rate=44100)
            stream.sample_rate = 44100
            stream.format = "fltp"

            audio_data = torch.zeros(1, 1024).numpy()
            audio_frame = av.AudioFrame.from_ndarray(
                audio_data, format="fltp", layout="mono"
            )
            audio_frame.sample_rate = 44100
            audio_frame.pts = 0
            packet = stream.encode(audio_frame)
            container.mux(packet)

            for packet in stream.encode(None):
                container.mux(packet)

        with pytest.raises(ValueError, match="No video stream found"):
            video = VideoFromFile(tmp_name)
            video.get_dimensions()
    finally:
        os.unlink(tmp_name)


def test_single_frame_video():
    """Single frame video has correct duration"""
    components = VideoComponents(
        images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
    )
    video = VideoFromComponents(components)
    assert video.get_duration() == 1.0


@pytest.mark.parametrize(
    "frame_rate,expected_fps",
    [
        (Fraction(24000, 1001), 24000 / 1001),
        (Fraction(30000, 1001), 30000 / 1001),
        (Fraction(25, 1), 25.0),
        (Fraction(50, 2), 25.0),
    ],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
    """Duration calculated correctly for various fractional frame rates"""
    components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
    video = VideoFromComponents(components)
    duration = video.get_duration()
    expected_duration = 100.0 / expected_fps
    assert duration == pytest.approx(expected_duration)


def test_duration_consistency(video_components):
    """get_duration() consistent with manual calculation from components"""
    video = VideoFromComponents(video_components)

    duration = video.get_duration()
    components = video.get_components()
    manual_duration = float(components.images.shape[0] / components.frame_rate)

    assert duration == pytest.approx(manual_duration)
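All of the duration tests above reduce to the same arithmetic, duration = frame_count / frame_rate, with Fraction keeping NTSC-style rates exact until the final float conversion. A quick sketch:

from fractions import Fraction

frames = 100
frame_rate = Fraction(24000, 1001)  # 23.976... fps, stored exactly
duration = float(frames / frame_rate)
assert abs(duration - 100 * 1001 / 24000) < 1e-9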
Block a user