Compare commits
58 Commits
yoland68-p ... yoland68-p
| SHA1 |
|---|
| ade9a4e034 |
| 87f9130778 |
| 7e84bf5373 |
| 4f3b50ba51 |
| e930a387d6 |
| d8e5662822 |
| 3d44a09812 |
| 62690eddec |
| 05eb10b43a |
| f5e4e976f4 |
| aee2908d03 |
| dc46db7aa4 |
| 7046983d95 |
| 1c2d45d2b5 |
| c820ef950d |
| 6a2e4bb9e0 |
| f1f9763b4c |
| 08368f8e00 |
| f3ff5c40db |
| 98ff01e148 |
| bab836d88d |
| 4a9014e201 |
| 8a7c894d54 |
| a814f2e8cc |
| 481732a0ed |
| 2156ce9453 |
| 4136502b7a |
| 9ad287ff20 |
| f5cacaeb14 |
| b7ed5f57bd |
| b4abca828e |
| 158419f3a0 |
| 640c47e7de |
| 31e9e36c94 |
| 577de83ca9 |
| 3535909eb8 |
| 235d3901fc |
| d42613686f |
| 1b3bf0a5da |
| ae60b150e5 |
| 42da274717 |
| 28f178a840 |
| 8ab15c863c |
| 924d771e18 |
| 02a1b01aad |
| a692c3cca4 |
| 5d3cc85e13 |
| c7c025b8d1 |
| fd08e39588 |
| 56b6ee6754 |
| cc33cd3422 |
| b9980592c4 |
| 16417b40d9 |
| 271c9c5b9e |
| a4e679765e |
| 0cf2e46b17 |
| 094e9ef126 |
| 1271c4ef9d |
.github/workflows/test-ci.yml (vendored): 56 changes
@@ -20,9 +20,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        # os: [macos, linux, windows]
-        os: [macos, linux]
-        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        os: [macos, linux, windows]
+        # os: [macos, linux]
+        python_version: ["3.10", "3.11", "3.12"]
         cuda_version: ["12.1"]
         torch_version: ["stable"]
         include:
@@ -32,9 +32,9 @@ jobs:
           - os: linux
             runner_label: [self-hosted, Linux]
             flags: ""
-          # - os: windows
-          #   runner_label: [self-hosted, Windows]
-          #   flags: ""
+          - os: windows
+            runner_label: [self-hosted, Windows]
+            flags: ""
     runs-on: ${{ matrix.runner_label }}
     steps:
       - name: Test Workflows
@@ -46,28 +46,28 @@ jobs:
           google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
           comfyui_flags: ${{ matrix.flags }}

-  # test-win-nightly:
-  #   strategy:
-  #     fail-fast: true
-  #     matrix:
-  #       os: [windows]
-  #       python_version: ["3.9", "3.10", "3.11", "3.12"]
-  #       cuda_version: ["12.1"]
-  #       torch_version: ["nightly"]
-  #       include:
-  #         - os: windows
-  #           runner_label: [self-hosted, Windows]
-  #           flags: ""
-  #   runs-on: ${{ matrix.runner_label }}
-  #   steps:
-  #     - name: Test Workflows
-  #       uses: comfy-org/comfy-action@main
-  #       with:
-  #         os: ${{ matrix.os }}
-  #         python_version: ${{ matrix.python_version }}
-  #         torch_version: ${{ matrix.torch_version }}
-  #         google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
-  #         comfyui_flags: ${{ matrix.flags }}
+  test-win-nightly:
+    strategy:
+      fail-fast: true
+      matrix:
+        os: [windows]
+        python_version: ["3.10", "3.11", "3.12"]
+        cuda_version: ["12.1"]
+        torch_version: ["nightly"]
+        include:
+          - os: windows
+            runner_label: [self-hosted, Windows]
+            flags: ""
+    runs-on: ${{ matrix.runner_label }}
+    steps:
+      - name: Test Workflows
+        uses: comfy-org/comfy-action@main
+        with:
+          os: ${{ matrix.os }}
+          python_version: ${{ matrix.python_version }}
+          torch_version: ${{ matrix.torch_version }}
+          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+          comfyui_flags: ${{ matrix.flags }}

   test-unix-nightly:
     strategy:
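The practical effect of the matrix change is that Python 3.9 is dropped and Windows joins the stable test matrix. A minimal sketch (plain Python, not part of the workflow) that expands the stable matrix above, just to see how many runner jobs it produces:

```python
from itertools import product

# Values taken from the stable-test matrix above.
oses = ["macos", "linux", "windows"]
python_versions = ["3.10", "3.11", "3.12"]
cuda_versions = ["12.1"]
torch_versions = ["stable"]

for combo in product(oses, python_versions, cuda_versions, torch_versions):
    print(combo)

# 3 * 3 * 1 * 1 = 9 job combinations, each routed to a self-hosted runner via runner_label.
```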
.github/workflows/test-launch.yml (vendored): 2 changes
@@ -17,7 +17,7 @@ jobs:
           path: "ComfyUI"
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.9'
+          python-version: '3.10'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
.github/workflows/update-api-stubs.yml (vendored): 13 changes
@@ -22,10 +22,19 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install 'datamodel-code-generator[http]'
+        npm install @redocly/cli
+
+    - name: Download OpenAPI spec
+      run: |
+        curl -o openapi.yaml https://api.comfy.org/openapi
+
+    - name: Filter OpenAPI spec with Redocly
+      run: |
+        npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components

     - name: Generate API models
       run: |
-        datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
+        datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel

     - name: Check for changes
       id: git-check
@@ -44,4 +53,4 @@ jobs:
           Generated automatically by the a Github workflow.
         branch: update-api-stubs
         delete-branch: true
-        base: main
+        base: master
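For context, datamodel-codegen with `--output-model-type pydantic_v2.BaseModel` emits plain pydantic v2 classes from the (now pre-filtered) OpenAPI spec. A hypothetical illustration of the kind of stub it generates; the class and field names here are made up and not taken from the comfy.org spec:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class JobStatus(str, Enum):
    # --use-subclass-enum makes generated enums subclass their value type (str here)
    pending = "pending"
    completed = "completed"


class Job(BaseModel):
    id: str
    status: JobStatus
    progress: Optional[float] = None
```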
.gitignore (vendored): 3 changes
@@ -21,3 +21,6 @@ venv/
 *.log
 web_custom_versions/
 .DS_Store
+openapi.yaml
+filtered-openapi.yaml
+uv.lock
README.md: 13 changes
@@ -69,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
 - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
 - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+- Audio Models
+   - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+   - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - 3D Models
    - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
-- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -108,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r
 
 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
    - Builds a new release using the latest stable core version
-   - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
 
 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
    - Weekly frontend updates are merged into the core repository
@@ -196,11 +197,11 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
 
-This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
 
 ### Intel GPUs (Windows and Linux)
 
@@ -300,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
 
 ### AMD ROCm Tips
 
-You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
 
 ```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
 
@@ -142,12 +142,15 @@ class PerformanceFeature(enum.Enum):
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
 
+parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
 
 parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
 parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
 
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
 
@@ -192,6 +195,13 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
 
 parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
 
+parser.add_argument(
+    "--comfy-api-base",
+    type=str,
+    default="https://api.comfy.org",
+    help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
+)
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
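These options are ordinary argparse flags (the hunk above appears to come from ComfyUI's comfy/cli_args.py), so the new API base can be overridden per invocation. A minimal standalone sketch of how the new flag behaves; this is not the actual comfy.cli_args module, and the override URL is hypothetical:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--comfy-api-base",
    type=str,
    default="https://api.comfy.org",
    help="Set the base URL for the ComfyUI API.",
)

# Pointing API nodes at a different endpoint (hypothetical URL):
args = parser.parse_args(["--comfy-api-base", "https://staging.api.example.org"])
print(args.comfy_api_base)  # argparse maps --comfy-api-base to args.comfy_api_base
```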
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
     API_NODE: Optional[bool]
-    """Flags a node as an API node."""
+    """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
 
     @classmethod
     @abstractmethod
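A sketch of how a node class might set the flag documented above. The class below is illustrative only and follows the usual ComfyUI node conventions; it is not a node from this changeset:

```python
class ExampleAPINode:
    # Hypothetical node illustrating the class attributes described in ComfyNodeABC.
    API_NODE = True    # marks this as an API node (see docs.comfy.org/tutorials/api-nodes/overview)
    DEPRECATED = False

    CATEGORY = "api node/example"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"prompt": ("STRING", {"multiline": True})}}

    def run(self, prompt):
        return (prompt,)
```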
@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
     phi1_fn = lambda t: torch.expm1(t) / t
     phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
 
+    old_sigma_down = None
     old_denoised = None
     uncond_denoised = None
     def post_cfg_function(args):
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             x = x + d * dt
         else:
             # Second order multistep method in https://arxiv.org/pdf/2308.02157
-            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
+            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
             h = t_next - t
-            c2 = (t_prev - t) / h
+            c2 = (t_prev - t_old) / h
 
             phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
             b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             old_denoised = uncond_denoised
         else:
             old_denoised = denoised
+        old_sigma_down = sigma_down
     return x
 
 @torch.no_grad()
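The sampler change above tracks the previous step's sigma_down so that the second-order coefficient c2 is computed over the interval the sampler actually traversed, rather than assuming it stepped from sigmas[i - 1] straight to sigmas[i]. A small numerical sketch of the corrected coefficient, assuming t_fn is the usual negative-log-sigma mapping used by this sampler; the sigma values are made up:

```python
import torch

t_fn = lambda sigma: sigma.log().neg()  # assumption: same t_fn convention as the sampler

sigmas = torch.tensor([10.0, 5.0, 2.0])
old_sigma_down = torch.tensor(4.5)  # sigma_down produced by the previous iteration (hypothetical)
sigma_down = torch.tensor(1.8)      # sigma_down for the current iteration (hypothetical)
i = 2

t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
c2_before = (t_prev - t) / h      # old formula: measures from sigmas[i - 1] to sigmas[i]
c2_after = (t_prev - t_old) / h   # new formula: measures from sigmas[i - 1] to the previous sigma_down
print(c2_before.item(), c2_after.item())
```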
@@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
     scale_factor = 1.0188137142395404
+
+class ACEAudio(LatentFormat):
+    latent_channels = 8
+    latent_dimensions = 2
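ACEAudio only declares channel and dimension metadata and inherits the default scale_factor of 1.0. A minimal sketch of how these attributes are typically consumed, with a stand-in base class; the process_in/process_out pattern is an assumption based on the usual LatentFormat behaviour, not part of this diff, and the latent shape is hypothetical:

```python
import torch


class LatentFormat:
    # Stand-in for comfy.latent_formats.LatentFormat: scale latents on the way
    # into and out of the diffusion model.
    scale_factor = 1.0

    def process_in(self, latent):
        return latent * self.scale_factor

    def process_out(self, latent):
        return latent / self.scale_factor


class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2  # 2D latent grid, unlike Hunyuan3Dv2mini's 1D latents


fmt = ACEAudio()
x = torch.randn(1, ACEAudio.latent_channels, 16, 256)  # hypothetical audio latent shape
assert torch.equal(fmt.process_in(x), x)  # scale_factor of 1.0 leaves latents unchanged
```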
comfy/ldm/ace/attention.py (new file): 761 lines
@@ -0,0 +1,761 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
processor=None,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
is_causal: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||
raise ValueError(
|
||||
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||
)
|
||||
|
||||
self.group_norm = None
|
||||
self.spatial_norm = None
|
||||
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
|
||||
self.norm_cross = None
|
||||
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.add_q_proj = None
|
||||
self.add_k_proj = None
|
||||
self.add_v_proj = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_add_out = None
|
||||
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CustomLiteLAProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
|
||||
|
||||
def __init__(self):
|
||||
self.kernel_func = nn.ReLU(inplace=False)
|
||||
self.eps = 1e-15
|
||||
self.pad_val = 1.0
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states_len = hidden_states.shape[1]
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
if encoder_hidden_states is not None:
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
dtype = hidden_states.dtype
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# attention
|
||||
if not attn.is_cross_attention:
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
else:
|
||||
query = hidden_states
|
||||
key = encoder_hidden_states
|
||||
value = encoder_hidden_states
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
|
||||
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
|
||||
# RoPE requires input of shape [B, H, S, D]
|
||||
# query is currently [B, H, D, S]; it must be transposed to [B, H, S, D] before RoPE can be applied
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (from [B, H, D, S])
|
||||
|
||||
# Apply query and key normalization if needed
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
# query is now [B, H, S, D]; restore it to [B, H, D, S]
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
|
||||
|
||||
if attention_mask is not None:
|
||||
# attention_mask: [B, S] -> [B, 1, S, 1]
|
||||
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
|
||||
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
|
||||
if not attn.is_cross_attention:
|
||||
key = key * attention_mask # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
|
||||
value = value * attention_mask.permute(0, 1, 3, 2) # value is [B, h, D, S], so the mask is permuted to match the S dimension
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
|
||||
# at this point key: [B, h, S_enc, D], value: [B, h, D, S_enc]
|
||||
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
|
||||
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
|
||||
|
||||
query = self.kernel_func(query)
|
||||
key = self.kernel_func(key)
|
||||
|
||||
query, key, value = query.float(), key.float(), value.float()
|
||||
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
|
||||
|
||||
vk = torch.matmul(value, key)
|
||||
|
||||
hidden_states = torch.matmul(vk, query)
|
||||
|
||||
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.float()
|
||||
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : hidden_states_len],
|
||||
hidden_states[:, hidden_states_len:],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if encoder_hidden_states is not None and context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if torch.get_autocast_gpu_dtype() == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CustomerAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
# attention_mask: N x S1
|
||||
# encoder_attention_mask: N x S2
|
||||
# cross attention: combine attention_mask and encoder_attention_mask
|
||||
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
|
||||
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
|
||||
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
|
||||
|
||||
elif not attn.is_cross_attention and attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
|
||||
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x for _ in range(repeat_time)]
|
||||
|
||||
|
||||
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
|
||||
"""Return tuple with min_len by repeating element at idx_repeat."""
|
||||
# convert to list first
|
||||
x = val2list(x)
|
||||
|
||||
# repeat elements if necessary
|
||||
if len(x) > 0:
|
||||
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||||
|
||||
return tuple(x)
|
||||
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
|
||||
if isinstance(kernel_size, tuple):
|
||||
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||||
else:
|
||||
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
|
||||
return kernel_size // 2
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=None,
|
||||
act=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if padding is None:
|
||||
padding = get_same_padding(kernel_size)
|
||||
padding *= dilation
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.padding = padding
|
||||
self.use_bias = use_bias
|
||||
|
||||
self.conv = operations.Conv1d(
|
||||
in_dim,
|
||||
out_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=use_bias,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if norm is not None:
|
||||
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm = None
|
||||
if act is not None:
|
||||
self.act = nn.SiLU(inplace=True)
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
if self.norm:
|
||||
x = self.norm(x)
|
||||
if self.act:
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_feature=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dilation=1,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
out_feature = out_feature or in_features
|
||||
super().__init__()
|
||||
use_bias = val2tuple(use_bias, 3)
|
||||
norm = val2tuple(norm, 3)
|
||||
act = val2tuple(act, 3)
|
||||
|
||||
self.glu_act = nn.SiLU(inplace=False)
|
||||
self.inverted_conv = ConvLayer(
|
||||
in_features,
|
||||
hidden_features * 2,
|
||||
1,
|
||||
use_bias=use_bias[0],
|
||||
norm=norm[0],
|
||||
act=act[0],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.depth_conv = ConvLayer(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
groups=hidden_features * 2,
|
||||
padding=padding,
|
||||
use_bias=use_bias[1],
|
||||
norm=norm[1],
|
||||
act=None,
|
||||
dilation=dilation,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.point_conv = ConvLayer(
|
||||
hidden_features,
|
||||
out_feature,
|
||||
1,
|
||||
use_bias=use_bias[2],
|
||||
norm=norm[2],
|
||||
act=act[2],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.inverted_conv(x)
|
||||
x = self.depth_conv(x)
|
||||
|
||||
x, gate = torch.chunk(x, 2, dim=1)
|
||||
gate = self.glu_act(gate)
|
||||
x = x * gate
|
||||
|
||||
x = self.point_conv(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LinearTransformerBlock(nn.Module):
|
||||
"""
|
||||
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
use_adaln_single=True,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=None,
|
||||
context_pre_only=False,
|
||||
mlp_ratio=4.0,
|
||||
add_cross_attention=False,
|
||||
add_cross_attention_dim=None,
|
||||
qk_norm=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomLiteLAProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.add_cross_attention = add_cross_attention
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
if add_cross_attention and add_cross_attention_dim is not None:
|
||||
self.cross_attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=add_cross_attention_dim,
|
||||
added_kv_proj_dim=add_cross_attention_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomerAttnProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
|
||||
|
||||
self.ff = GLUMBConv(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
use_bias=(True, True, False),
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.use_adaln_single = use_adaln_single
|
||||
if use_adaln_single:
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: torch.FloatTensor = None,
|
||||
encoder_attention_mask: torch.FloatTensor = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
|
||||
# step 1: AdaLN single
|
||||
if self.use_adaln_single:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
# step 2: attention
|
||||
if not self.add_cross_attention:
|
||||
attn_output, encoder_hidden_states = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
attn_output = gate_msa * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.add_cross_attention:
|
||||
attn_output = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# step 3: add norm
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
# step 4: feed forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self.use_adaln_single:
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
return hidden_states
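The core of CustomLiteLAProcessor2_0 above is a softmax-free linear attention: queries and keys pass through a ReLU kernel, a row of ones is padded onto the values so a single pair of matmuls also accumulates the normalizer, and the output is divided by that last row. A small self-contained sketch of just that normalization trick; shapes and values are illustrative, not taken from the model config:

```python
import torch
import torch.nn.functional as F

B, H, D, S = 1, 2, 4, 6
query = F.relu(torch.randn(B, H, D, S))   # [B, H, D, S], after the ReLU kernel_func
key = F.relu(torch.randn(B, H, S, D))     # [B, H, S, D]
value = torch.randn(B, H, D, S)           # [B, H, D, S]

value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)  # -> [B, H, D + 1, S]
vk = torch.matmul(value, key)             # [B, H, D + 1, D]
out = torch.matmul(vk, query)             # [B, H, D + 1, S]
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)  # divide by the accumulated normalizer row
print(out.shape)                          # torch.Size([1, 2, 4, 6])
```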
comfy/ldm/ace/lyric_encoder.py (new file): 1067 lines (file diff suppressed because it is too large)
comfy/ldm/ace/model.py (new file): 385 lines
@@ -0,0 +1,385 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
|
||||
|
||||
def cross_norm(hidden_states, controlnet_input):
|
||||
# input N x T x c
|
||||
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
||||
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
||||
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
||||
return controlnet_input
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of Sana.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
def unpatchfy(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
width: int,
|
||||
):
|
||||
# 4 unpatchify
|
||||
new_height, new_width = 1, hidden_states.size(1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
|
||||
).contiguous()
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
|
||||
).contiguous()
|
||||
if width > new_width:
|
||||
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
|
||||
elif width < new_width:
|
||||
output = output[:, :, :, :width]
|
||||
return output
|
||||
|
||||
def forward(self, x, t, output_length):
|
||||
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
# unpatchify
|
||||
output = self.unpatchfy(x, output_length)
|
||||
return output
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=16,
|
||||
width=4096,
|
||||
patch_size=(16, 1),
|
||||
in_channels=8,
|
||||
embed_dim=1152,
|
||||
bias=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
patch_size_h, patch_size_w = patch_size
|
||||
self.early_conv_layers = nn.Sequential(
|
||||
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
|
||||
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
|
||||
)
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size_h, width // patch_size_w
|
||||
self.base_size = self.width
|
||||
|
||||
def forward(self, latent):
|
||||
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
|
||||
latent = self.early_conv_layers(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return latent
|
||||
|
||||
|
||||
class ACEStepTransformer2DModel(nn.Module):
|
||||
# _supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: Optional[int] = 8,
|
||||
num_layers: int = 28,
|
||||
inner_dim: int = 1536,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_channels: int = 8,
|
||||
max_position: int = 32768,
|
||||
rope_theta: float = 1000000.0,
|
||||
speaker_embedding_dim: int = 512,
|
||||
text_embedding_dim: int = 768,
|
||||
ssl_encoder_depths: List[int] = [9, 9],
|
||||
ssl_names: List[str] = ["mert", "m-hubert"],
|
||||
ssl_latent_dims: List[int] = [1024, 768],
|
||||
lyric_encoder_vocab_size: int = 6681,
|
||||
lyric_hidden_size: int = 1024,
|
||||
patch_size: List[int] = [16, 1],
|
||||
max_height: int = 16,
|
||||
max_width: int = 4096,
|
||||
audio_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.out_channels = out_channels
|
||||
self.max_position = max_position
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
dim=self.attention_head_dim,
|
||||
max_position_embeddings=self.max_position,
|
||||
base=self.rope_theta,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_layers = num_layers
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LinearTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
add_cross_attention=True,
|
||||
add_cross_attention_dim=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
|
||||
|
||||
# speaker
|
||||
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# genre
|
||||
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# lyric
|
||||
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
|
||||
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
|
||||
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
projector_dim = 2 * self.inner_dim
|
||||
|
||||
self.projectors = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
|
||||
) for ssl_dim in ssl_latent_dims
|
||||
])
|
||||
|
||||
self.proj_in = PatchEmbed(
|
||||
height=max_height,
|
||||
width=max_width,
|
||||
patch_size=patch_size,
|
||||
embed_dim=self.inner_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_lyric_encoder(
|
||||
self,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
out_dtype=None,
|
||||
):
|
||||
# N x T x D
|
||||
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
|
||||
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
|
||||
return prompt_prenet_out
|
||||
|
||||
def encode(
|
||||
self,
|
||||
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
lyrics_strength=1.0,
|
||||
):
|
||||
|
||||
bs = encoder_text_hidden_states.shape[0]
|
||||
device = encoder_text_hidden_states.device
|
||||
|
||||
# speaker embedding
|
||||
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
|
||||
|
||||
# genre embedding
|
||||
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
|
||||
|
||||
# lyric
|
||||
encoder_lyric_hidden_states = self.forward_lyric_encoder(
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
out_dtype=encoder_text_hidden_states.dtype,
|
||||
)
|
||||
|
||||
encoder_lyric_hidden_states *= lyrics_strength
|
||||
|
||||
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
|
||||
|
||||
encoder_hidden_mask = None
|
||||
if text_attention_mask is not None:
|
||||
speaker_mask = torch.ones(bs, 1, device=device)
|
||||
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
|
||||
|
||||
return encoder_hidden_states, encoder_hidden_mask
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_mask: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor],
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# controlnet logic
|
||||
if block_controlnet_hidden_states is not None:
|
||||
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
|
||||
hidden_states = hidden_states + control_condi * controlnet_scale
|
||||
|
||||
# inner_hidden_states = []
|
||||
|
||||
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
|
||||
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_hidden_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
hidden_states = x
|
||||
encoder_text_hidden_states = context
|
||||
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
speaker_embeds=speaker_embeds,
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
lyrics_strength=lyrics_strength,
|
||||
)
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_mask=encoder_hidden_mask,
|
||||
timestep=timestep,
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
)
|
||||
|
||||
return output
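# Editor's note (not part of the diff): encode() above assembles a single conditioning
# sequence by concatenating, along the token dimension,
#     speaker embedding        (B, 1,       D)
#     projected genre text     (B, T_text,  D)
#     projected lyric tokens   (B, T_lyric, D)   scaled by lyrics_strength
# with the matching mask cat([ones(B, 1), text_attention_mask, lyric_mask], dim=1),
# and decode() cross-attends to that sequence from the noisy latent tokens at every
# transformer block. Shapes are illustrative; D is the model's inner dimension.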
|
||||
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
@@ -0,0 +1,644 @@
|
||||
# Rewritten from diffusers
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple, Union
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMSNorm(ops.RMSNorm):
|
||||
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
|
||||
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
if elementwise_affine:
|
||||
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
x = super().forward(x)
|
||||
if self.elementwise_affine:
|
||||
if self.bias is not None:
|
||||
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
|
||||
if norm_type == "batch_norm":
|
||||
return nn.BatchNorm2d(num_features)
|
||||
elif norm_type == "group_norm":
|
||||
return ops.GroupNorm(num_groups, num_features)
|
||||
elif norm_type == "layer_norm":
|
||||
return ops.LayerNorm(num_features)
|
||||
elif norm_type == "rms_norm":
|
||||
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {norm_type}")
|
||||
|
||||
|
||||
def get_activation(activation_type):
|
||||
if activation_type == "relu":
|
||||
return nn.ReLU()
|
||||
elif activation_type == "relu6":
|
||||
return nn.ReLU6()
|
||||
elif activation_type == "silu":
|
||||
return nn.SiLU()
|
||||
elif activation_type == "leaky_relu":
|
||||
return nn.LeakyReLU(0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown activation type: {activation_type}")
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = ops.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
)
|
||||
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: int = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: tuple = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult)
|
||||
if num_attention_heads is None
|
||||
else num_attention_heads
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, out_channels)
|
||||
|
||||
def apply_linear_attention(self, query, key, value):
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query, key, value):
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states):
|
||||
height, width = hidden_states.shape[-2:]
|
||||
if height * width > self.attention_head_dim:
|
||||
use_linear_attention = True
|
||||
else:
|
||||
use_linear_attention = False
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, _, height, width = list(hidden_states.size())
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
hidden_states = hidden_states.movedim(1, -1)
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
hidden_states = torch.cat([query, key, value], dim=3)
|
||||
hidden_states = hidden_states.movedim(-1, 1)
|
||||
|
||||
multi_scale_qkv = [hidden_states]
|
||||
for block in self.to_qkv_multiscale:
|
||||
multi_scale_qkv.append(block(hidden_states))
|
||||
|
||||
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||
|
||||
if use_linear_attention:
|
||||
# for linear attention upcast hidden_states to float32
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
|
||||
|
||||
query, key, value = hidden_states.chunk(3, dim=2)
|
||||
query = self.nonlinearity(query)
|
||||
key = self.nonlinearity(key)
|
||||
|
||||
if use_linear_attention:
|
||||
hidden_states = self.apply_linear_attention(query, key, value)
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
else:
|
||||
hidden_states = self.apply_quadratic_attention(query, key, value)
|
||||
|
||||
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
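# Editor's note (not part of the diff): the forward above picks the attention path by
# sequence length. When height * width exceeds attention_head_dim it uses the O(N)
# ReLU "linear" attention, where the extra row of ones padded onto `value` carries the
# normaliser (hence the division by hidden_states[:, :, -1:]); for small feature maps
# the O(N^2) form in apply_quadratic_attention is cheaper and is used instead.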
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: tuple = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = SanaMultiscaleLinearAttention(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
mult=mult,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
kernel_sizes=qkv_multiscales,
|
||||
residual_connection=True,
|
||||
)
|
||||
|
||||
self.conv_out = GLUMBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
norm_type="rms_norm",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.attn(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: str = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_block(
|
||||
block_type: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: tuple = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||
|
||||
return block
|
||||
|
||||
|
||||
class DCDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.factor = 2
|
||||
self.stride = 1 if downsample else 2
|
||||
self.group_size = in_channels * self.factor**2 // out_channels
|
||||
self.shortcut = shortcut
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if downsample:
|
||||
assert out_channels % out_ratio == 0
|
||||
out_channels = out_channels // out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(hidden_states)
|
||||
if self.downsample:
|
||||
x = F.pixel_unshuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||
y = y.unflatten(1, (-1, self.group_size))
|
||||
y = y.mean(dim=2)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
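# Editor's note (not part of the diff): the shortcut branch above downsamples without
# parameters by pixel-unshuffling (space-to-channel) and then averaging channel groups
# so it matches the conv branch. E.g. for in_channels=128, out_channels=256, factor=2:
#     group_size = 128 * 4 // 256 = 2
#     (B, 128, H, W) --pixel_unshuffle--> (B, 512, H/2, W/2) --group mean--> (B, 256, H/2, W/2)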
|
||||
|
||||
|
||||
class DCUpBlock2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
interpolate: bool = False,
|
||||
shortcut: bool = True,
|
||||
interpolation_mode: str = "nearest",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.interpolate = interpolate
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.shortcut = shortcut
|
||||
self.factor = 2
|
||||
self.repeats = out_channels * self.factor**2 // in_channels
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if not interpolate:
|
||||
out_channels = out_channels * out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.interpolate:
|
||||
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv(hidden_states)
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_in = ops.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
else:
|
||||
self.conv_in = DCDownBlock2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
down_blocks = []
|
||||
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||
down_block_list = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
downsample_block = DCDownBlock2d(
|
||||
in_channels=out_channel,
|
||||
out_channels=block_out_channels[i + 1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=True,
|
||||
)
|
||||
down_block_list.append(downsample_block)
|
||||
|
||||
down_blocks.append(nn.Sequential(*down_block_list))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||
|
||||
self.out_shortcut = out_shortcut
|
||||
if out_shortcut:
|
||||
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
if self.out_shortcut:
|
||||
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||
x = x.mean(dim=2)
|
||||
hidden_states = self.conv_out(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: str or tuple = "rms_norm",
|
||||
act_fn: str or tuple = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
if isinstance(norm_type, str):
|
||||
norm_type = (norm_type,) * num_blocks
|
||||
if isinstance(act_fn, str):
|
||||
act_fn = (act_fn,) * num_blocks
|
||||
|
||||
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||
|
||||
self.in_shortcut = in_shortcut
|
||||
if in_shortcut:
|
||||
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||
|
||||
up_blocks = []
|
||||
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||
up_block_list = []
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
upsample_block = DCUpBlock2d(
|
||||
block_out_channels[i + 1],
|
||||
out_channel,
|
||||
interpolate=upsample_block_type == "interpolate",
|
||||
shortcut=True,
|
||||
)
|
||||
up_block_list.append(upsample_block)
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
|
||||
else:
|
||||
self.conv_out = DCUpBlock2d(
|
||||
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(
|
||||
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||
)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for up_block in reversed(self.up_blocks):
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 2,
|
||||
latent_channels: int = 8,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
upsample_block_type: str = "interpolate",
|
||||
downsample_block_type: str = "Conv",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
scaling_factor: float = 0.41407,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=encoder_block_types,
|
||||
block_out_channels=encoder_block_out_channels,
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=decoder_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
qkv_multiscales=decoder_qkv_multiscales,
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
)
|
||||
|
||||
self.scaling_factor = scaling_factor
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Internal encoding function."""
|
||||
encoded = self.encoder(x)
|
||||
return encoded * self.scaling_factor
|
||||
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
# Scale the latents back
|
||||
z = z / self.scaling_factor
|
||||
decoded = self.decoder(z)
|
||||
return decoded
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
z = self.encode(x)
|
||||
return self.decode(z)
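# Editor's illustrative sketch (not part of the diff): a shape-only round trip through
# AutoencoderDC with its default config (in_channels=2, latent_channels=8, three
# stride-2 downsamples, i.e. 8x spatial compression). Sizes are arbitrary and the
# weights are uninitialised here (comfy.ops.disable_weight_init), so only shapes matter.
if __name__ == "__main__":
    ae = AutoencoderDC()
    x = torch.randn(1, 2, 64, 256)
    z = ae.encode(x)            # torch.Size([1, 8, 8, 32]), already scaled by scaling_factor
    x_rec = ae.decode(z)        # torch.Size([1, 2, 64, 256])
    print(z.shape, x_rec.shape)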
|
||||
|
||||
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
||||
import torch
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
import logging
|
||||
try:
    import torchaudio
except Exception:
    logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
|
||||
class MusicDCAE(torch.nn.Module):
|
||||
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
|
||||
super(MusicDCAE, self).__init__()
|
||||
|
||||
self.dcae = AutoencoderDC(**dcae_config)
|
||||
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
|
||||
|
||||
if source_sample_rate is None:
|
||||
self.source_sample_rate = 48000
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
self.min_mel_value = -11.0
|
||||
self.max_mel_value = 3.0
|
||||
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
|
||||
self.mel_chunk_size = 1024
|
||||
self.time_dimention_multiple = 8
|
||||
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
image = self.vocoder.mel_transform(audios[i])
|
||||
mels.append(image)
|
||||
mels = torch.stack(mels)
|
||||
return mels
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audios, audio_lengths=None, sr=None):
|
||||
if audio_lengths is None:
|
||||
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
|
||||
audio_lengths = audio_lengths.to(audios.device)
|
||||
|
||||
if sr is None:
|
||||
sr = self.source_sample_rate
|
||||
|
||||
if sr != 44100:
|
||||
audios = torchaudio.functional.resample(audios, sr, 44100)
|
||||
|
||||
max_audio_len = audios.shape[-1]
|
||||
if max_audio_len % (8 * 512) != 0:
|
||||
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
|
||||
|
||||
mels = self.forward_mel(audios)
|
||||
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
|
||||
mels = self.transform(mels)
|
||||
latents = []
|
||||
for mel in mels:
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
latents = latents / self.scale_factor + self.shift_factor
|
||||
|
||||
pred_wavs = []
|
||||
|
||||
for latent in latents:
|
||||
mels = self.dcae.decoder(latent.unsqueeze(0))
|
||||
mels = mels * 0.5 + 0.5
|
||||
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
    def forward(self, audios, audio_lengths=None, sr=None):
        # encode() and decode() above return just the latents and the stacked waveforms,
        # so unpack them accordingly (the old (latents, lengths) / (sr, wavs) pairs no
        # longer apply here).
        latents = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return pred_wavs, latents
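# Editor's sketch (not part of the diff): the latent normalisation applied by encode()
# and undone by decode(), using the scale/shift values hard-coded above. The latent
# shape is illustrative.
if __name__ == "__main__":
    scale_factor, shift_factor = 0.1786, -1.9091
    raw = torch.randn(1, 8, 16, 128)                      # raw DCAE encoder output
    normalized = (raw - shift_factor) * scale_factor      # what encode() returns
    recovered = normalized / scale_factor + shift_factor  # first step of decode()
    assert torch.allclose(raw, recovered, atol=1e-6)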
|
||||
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
@@ -0,0 +1,113 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import logging
|
||||
try:
    from torchaudio.transforms import MelScale
except Exception:
    logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
class LinearSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.mode = mode
|
||||
|
||||
self.register_buffer("window", torch.hann_window(win_length))
|
||||
|
||||
def forward(self, y: Tensor) -> Tensor:
|
||||
if y.ndim == 3:
|
||||
y = y.squeeze(1)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(
|
||||
(self.win_length - self.hop_length) // 2,
|
||||
(self.win_length - self.hop_length + 1) // 2,
|
||||
),
|
||||
mode="reflect",
|
||||
).squeeze(1)
|
||||
dtype = y.dtype
|
||||
spec = torch.stft(
|
||||
y.float(),
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
|
||||
center=self.center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
|
||||
if self.mode == "pow2_sqrt":
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = spec.to(dtype)
|
||||
return spec
|
||||
|
||||
|
||||
class LogMelSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=44100,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
n_mels=128,
|
||||
center=False,
|
||||
f_min=0.0,
|
||||
f_max=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.n_mels = n_mels
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max or sample_rate // 2
|
||||
|
||||
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
|
||||
self.mel_scale = MelScale(
|
||||
self.n_mels,
|
||||
self.sample_rate,
|
||||
self.f_min,
|
||||
self.f_max,
|
||||
self.n_fft // 2 + 1,
|
||||
"slaney",
|
||||
"slaney",
|
||||
)
|
||||
|
||||
def compress(self, x: Tensor) -> Tensor:
|
||||
return torch.log(torch.clamp(x, min=1e-5))
|
||||
|
||||
def decompress(self, x: Tensor) -> Tensor:
|
||||
return torch.exp(x)
|
||||
|
||||
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
|
||||
linear = self.spectrogram(x)
|
||||
x = self.mel_scale(linear)
|
||||
x = self.compress(x)
|
||||
# print(x.shape)
|
||||
if return_linear:
|
||||
return x, self.compress(linear)
|
||||
|
||||
return x
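# Editor's illustrative sketch (not part of the diff): computing a log-mel spectrogram
# with the module above, assuming torchaudio is available. One second of 44.1 kHz audio
# with win_length=2048, hop_length=512 and center=False yields 86 frames.
if __name__ == "__main__":
    mel_transform = LogMelSpectrogram(sample_rate=44100, n_mels=128, hop_length=512)
    audio = torch.randn(1, 44100)       # (batch, samples)
    mel = mel_transform(audio)          # torch.Size([1, 128, 86])
    print(mel.shape)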
|
||||
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
@@ -0,0 +1,538 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from functools import partial
|
||||
from math import prod
|
||||
from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
def drop_path(
|
||||
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
||||
):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (
|
||||
x.ndim - 1
|
||||
) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
|
||||
|
||||
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
||||
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
||||
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
||||
with shape (batch_size, channels, height, width).
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||
self.eps = eps
|
||||
self.data_format = data_format
|
||||
if self.data_format not in ["channels_last", "channels_first"]:
|
||||
raise NotImplementedError
|
||||
self.normalized_shape = (normalized_shape,)
|
||||
|
||||
def forward(self, x):
|
||||
if self.data_format == "channels_last":
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
|
||||
)
|
||||
elif self.data_format == "channels_first":
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
r"""ConvNeXt Block. There are two equivalent implementations:
|
||||
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
||||
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
||||
We use (2) as we find it slightly faster in PyTorch
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
drop_path (float): Stochastic depth rate. Default: 0.0
|
||||
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
||||
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
||||
dilation (int): Dilation for depthwise conv. Default: 1.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
drop_path: float = 0.0,
|
||||
layer_scale_init_value: float = 1e-6,
|
||||
mlp_ratio: float = 4.0,
|
||||
kernel_size: int = 7,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dwconv = ops.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=int(dilation * (kernel_size - 1) / 2),
|
||||
groups=dim,
|
||||
) # depthwise conv
|
||||
self.norm = LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = ops.Linear(
|
||||
dim, int(mlp_ratio * dim)
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(torch.empty((dim)), requires_grad=False)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x, apply_residual: bool = True):
|
||||
input = x
|
||||
|
||||
x = self.dwconv(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
|
||||
if self.gamma is not None:
|
||||
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
|
||||
|
||||
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
||||
x = self.drop_path(x)
|
||||
|
||||
if apply_residual:
|
||||
x = input + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ParallelConvNeXtBlock(nn.Module):
|
||||
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
|
||||
for kernel_size in kernel_sizes
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[block(x, apply_residual=False) for block in self.blocks] + [x],
|
||||
dim=1,
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
class ConvNeXtEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels=3,
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 768],
|
||||
drop_path_rate=0.0,
|
||||
layer_scale_init_value=1e-6,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
):
|
||||
super().__init__()
|
||||
assert len(depths) == len(dims)
|
||||
|
||||
self.channel_layers = nn.ModuleList()
|
||||
stem = nn.Sequential(
|
||||
ops.Conv1d(
|
||||
input_channels,
|
||||
dims[0],
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
padding_mode="replicate",
|
||||
),
|
||||
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
||||
)
|
||||
self.channel_layers.append(stem)
|
||||
|
||||
for i in range(len(depths) - 1):
|
||||
mid_layer = nn.Sequential(
|
||||
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
||||
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
|
||||
)
|
||||
self.channel_layers.append(mid_layer)
|
||||
|
||||
block_fn = (
|
||||
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
|
||||
if len(kernel_sizes) == 1
|
||||
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
|
||||
)
|
||||
|
||||
self.stages = nn.ModuleList()
|
||||
drop_path_rates = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
]
|
||||
|
||||
cur = 0
|
||||
for i in range(len(depths)):
|
||||
stage = nn.Sequential(
|
||||
*[
|
||||
block_fn(
|
||||
dim=dims[i],
|
||||
drop_path=drop_path_rates[cur + j],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
for j in range(depths[i])
|
||||
]
|
||||
)
|
||||
self.stages.append(stage)
|
||||
cur += depths[i]
|
||||
|
||||
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for channel_layer, stage in zip(self.channel_layers, self.stages):
|
||||
x = channel_layer(x)
|
||||
x = stage(x)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return (kernel_size * dilation - dilation) // 2
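# Editor's note (not part of the diff): this keeps "same" output length for the odd
# kernels used below, e.g. get_padding(7) == 3 and get_padding(3, dilation=5) == 5.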
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.silu(x)
|
||||
xt = c1(xt)
|
||||
xt = F.silu(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for conv in self.convs1:
|
||||
remove_weight_norm(conv)
|
||||
for conv in self.convs2:
|
||||
remove_weight_norm(conv)
|
||||
|
||||
|
||||
class HiFiGANGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hop_length: int = 512,
|
||||
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 128,
|
||||
upsample_initial_channel: int = 512,
|
||||
use_template: bool = True,
|
||||
pre_conv_kernel_size: int = 7,
|
||||
post_conv_kernel_size: int = 7,
|
||||
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
prod(upsample_rates) == hop_length
|
||||
), f"hop_length must be {prod(upsample_rates)}"
|
||||
|
||||
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
num_mels,
|
||||
upsample_initial_channel,
|
||||
pre_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(pre_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.use_template = use_template
|
||||
self.ups = nn.ModuleList()
|
||||
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not use_template:
|
||||
continue
|
||||
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||
self.noise_convs.append(
|
||||
ops.Conv1d(
|
||||
1,
|
||||
c_cur,
|
||||
kernel_size=stride_f0 * 2,
|
||||
stride=stride_f0,
|
||||
padding=stride_f0 // 2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
self.resblocks.append(ResBlock1(ch, k, d))
|
||||
|
||||
self.activation_post = post_activation()
|
||||
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
ch,
|
||||
1,
|
||||
post_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(post_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x, template=None):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if self.use_template:
|
||||
x = x + self.noise_convs[i](template)
|
||||
|
||||
xs = None
|
||||
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for up in self.ups:
|
||||
remove_weight_norm(up)
|
||||
for block in self.resblocks:
|
||||
block.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class ADaMoSHiFiGANV1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 128,
|
||||
depths: List[int] = [3, 3, 9, 3],
|
||||
dims: List[int] = [128, 256, 384, 512],
|
||||
drop_path_rate: float = 0.0,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 512,
|
||||
upsample_initial_channel: int = 1024,
|
||||
use_template: bool = False,
|
||||
pre_conv_kernel_size: int = 13,
|
||||
post_conv_kernel_size: int = 13,
|
||||
sampling_rate: int = 44100,
|
||||
n_fft: int = 2048,
|
||||
win_length: int = 2048,
|
||||
hop_length: int = 512,
|
||||
f_min: int = 40,
|
||||
f_max: int = 16000,
|
||||
n_mels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = ConvNeXtEncoder(
|
||||
input_channels=input_channels,
|
||||
depths=depths,
|
||||
dims=dims,
|
||||
drop_path_rate=drop_path_rate,
|
||||
kernel_sizes=kernel_sizes,
|
||||
)
|
||||
|
||||
self.head = HiFiGANGenerator(
|
||||
hop_length=hop_length,
|
||||
upsample_rates=upsample_rates,
|
||||
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||
resblock_kernel_sizes=resblock_kernel_sizes,
|
||||
resblock_dilation_sizes=resblock_dilation_sizes,
|
||||
num_mels=num_mels,
|
||||
upsample_initial_channel=upsample_initial_channel,
|
||||
use_template=use_template,
|
||||
pre_conv_kernel_size=pre_conv_kernel_size,
|
||||
post_conv_kernel_size=post_conv_kernel_size,
|
||||
)
|
||||
self.sampling_rate = sampling_rate
|
||||
self.mel_transform = LogMelSpectrogram(
|
||||
sample_rate=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
f_min=f_min,
|
||||
f_max=f_max,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
self.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x):
|
||||
return self.mel_transform(x)
|
||||
|
||||
def forward(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
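# Editor's illustrative sketch (not part of the diff): the vocoder maps a
# (batch, 128, frames) log-mel to a waveform; prod(upsample_rates) = 4*4*2*2*2*2*2 = 512
# equals hop_length, so each mel frame expands to 512 samples. Weights are uninitialised
# here, so only the shapes are meaningful, and constructing the module assumes
# torchaudio is installed (for the internal LogMelSpectrogram).
if __name__ == "__main__":
    vocoder = ADaMoSHiFiGANV1()
    mel = torch.randn(1, 128, 86)       # roughly one second at 44.1 kHz / hop 512
    wav = vocoder.decode(mel)           # torch.Size([1, 1, 44032]) == 86 * 512 samples
    print(wav.shape)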
|
||||
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
|
||||
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
img = torch.cat([ref_latent, img], dim=-2)
|
||||
ref_latent_ids[..., 0] = -1
|
||||
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
|
||||
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
def img_ids(self, x):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
return c_skip, c
|
||||
|
||||
|
||||
class WanCamAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
|
||||
super(WanCamAdapter, self).__init__()
|
||||
|
||||
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
||||
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
||||
|
||||
# Convolution: reduce spatial dimensions by a factor
|
||||
# of 2 (without overlap)
|
||||
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
# Residual blocks for feature extraction
|
||||
self.residual_blocks = nn.Sequential(
|
||||
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Reshape to merge the frame dimension into batch
|
||||
bs, c, f, h, w = x.size()
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
||||
|
||||
# Pixel Unshuffle operation
|
||||
x_unshuffled = self.pixel_unshuffle(x)
|
||||
|
||||
# Convolution operation
|
||||
x_conv = self.conv(x_unshuffled)
|
||||
|
||||
# Feature extraction with residual blocks
|
||||
out = self.residual_blocks(x_conv)
|
||||
|
||||
# Reshape to restore original bf dimension
|
||||
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
||||
|
||||
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return out
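# Editor's note (not part of the diff): shape trace through WanCamAdapter for a camera
# condition tensor with in_dim=24 and kernel_size=stride=(2, 2) (the spatial patch size):
#     x:                         (bs, 24,    f, h,    w)
#     merge frames into batch:   (bs*f, 24,     h,    w)
#     pixel_unshuffle(8):        (bs*f, 24*64,  h/8,  w/8)
#     conv (k=2, s=2):           (bs*f, dim,    h/16, w/16)
#     residual blocks keep the shape; reshape/permute back to (bs, dim, f, h/16, w/16).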
|
||||
|
||||
|
||||
class WanCamResidualBlock(nn.Module):
|
||||
def __init__(self, dim, operation_settings={}):
|
||||
super(WanCamResidualBlock, self).__init__()
|
||||
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += residual
|
||||
return out
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||
@@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class CameraWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='camera',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
image_model=None,
|
||||
in_dim_control_adapter=24,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
camera_conditions = None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
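# Illustrative mapping (the parameter name below is hypothetical, chosen only
# to show the string manipulation above): an official ACE-Step lora weight
#   "diffusion_model.some_block.proj.weight"
# is exposed for lora key matching under the stripped key
#   "some_block.proj"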
|
||||
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -923,6 +924,10 @@ class HunyuanVideo(BaseModel):
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
ref_latent = kwargs.get("ref_latent", None)
|
||||
if ref_latent is not None:
|
||||
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
@@ -1074,6 +1079,17 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
camera_conditions = kwargs.get("camera_conditions", None)
|
||||
if camera_conditions is not None:
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
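# Hedged data-flow note: the "camera_conditions" cond registered here is what
# eventually reaches CameraWanModel.forward_orig as the camera_conditions
# kwarg, where the control_adapter output is added onto the patch embedding.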
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@@ -1111,7 +1127,7 @@ class HiDream(BaseModel):
|
||||
return out
|
||||
|
||||
class Chroma(Flux):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
@@ -1121,3 +1137,22 @@ class Chroma(Flux):
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class ACEStep(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
if conditioning_lyrics is not None:
|
||||
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
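# Descriptive summary of the conds assembled above (comment only, no new
# behaviour):
#   c_crossattn     - text encoder hidden states
#   lyric_token_idx - lyric token ids from the ACE lyrics tokenizer
#   speaker_embeds  - currently an all-zeros (batch, 512) placeholder
#   lyrics_strength - scalar weight for the lyric conditioning (default 1.0)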
|
||||
|
||||
@@ -222,10 +222,39 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
dit_config["cross_attention_dim"] = shape[1]
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
return dit_config
|
||||
|
||||
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
|
||||
dit_config = {}
|
||||
dit_config["audio_model"] = "ace"
|
||||
dit_config["attention_head_dim"] = 128
|
||||
dit_config["in_channels"] = 8
|
||||
dit_config["inner_dim"] = 2560
|
||||
dit_config["max_height"] = 16
|
||||
dit_config["max_position"] = 32768
|
||||
dit_config["max_width"] = 32768
|
||||
dit_config["mlp_ratio"] = 2.5
|
||||
dit_config["num_attention_heads"] = 20
|
||||
dit_config["num_layers"] = 24
|
||||
dit_config["out_channels"] = 8
|
||||
dit_config["patch_size"] = [16, 1]
|
||||
dit_config["rope_theta"] = 1000000.0
|
||||
dit_config["speaker_embedding_dim"] = 512
|
||||
dit_config["text_embedding_dim"] = 768
|
||||
|
||||
dit_config["ssl_encoder_depths"] = [8, 8]
|
||||
dit_config["ssl_latent_dims"] = [1024, 768]
|
||||
dit_config["ssl_names"] = ["mert", "m-hubert"]
|
||||
dit_config["lyric_encoder_vocab_size"] = 6693
|
||||
dit_config["lyric_hidden_size"] = 1024
|
||||
return dit_config
|
||||
|
||||
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
|
||||
patch_size = 2
|
||||
dit_config = {}
|
||||
@@ -332,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "vace"
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
@@ -308,10 +308,10 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
@@ -30,7 +30,7 @@ if RMSNorm is None:
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=None,
|
||||
eps=1e-6,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
|
||||
60
comfy/sd.py
@@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import yaml
|
||||
import math
|
||||
|
||||
@@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -280,6 +282,7 @@ class VAE:
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@@ -437,6 +440,20 @@ class VAE:
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -495,7 +512,13 @@ class VAE:
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
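# Hedged note on the branch above: latents that carry an extra 1-D channel
# (e.g. the ACE Step audio latent, shaped (B, C, extra, T)) are flattened to
# (B, C * extra, T) so tiled_scale_multidim can tile along the single time
# axis; decode_fn then restores the 4-D layout before calling the decoder.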
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
@@ -515,9 +538,24 @@ class VAE:
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
extra_channel_size = self.extra_1d_channel
|
||||
out_channels = self.latent_channels * extra_channel_size
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
return out
|
||||
else:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
@@ -542,7 +580,7 @@ class VAE:
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
@@ -609,7 +647,7 @@ class VAE:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1:
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
@@ -715,6 +753,7 @@ class CLIPType(Enum):
|
||||
WAN = 13
|
||||
HIDREAM = 14
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -840,8 +879,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
|
||||
clip_target.clip = comfy.text_encoders.ace.AceT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
@@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -785,6 +786,10 @@ class LTXV(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
@@ -987,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
|
||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera",
|
||||
"in_dim": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1096,6 +1111,34 @@ class Chroma(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
|
||||
class ACEStep(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.ACEAudio
|
||||
|
||||
memory_usage_factor = 0.5
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.ACEStep(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
153
comfy/text_encoders/ace.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
|
||||
|
||||
SUPPORT_LANGUAGES = {
|
||||
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
|
||||
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
|
||||
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
|
||||
"ko": 6152, "hi": 6680
|
||||
}
|
||||
|
||||
structure_pattern = re.compile(r"\[.*?\]")
|
||||
|
||||
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
|
||||
self.tokenizer = None
|
||||
if vocab_file is not None:
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang='en'):
|
||||
# lang = lang.split("-")[0] # remove the region
|
||||
# self.check_input_length(txt, lang)
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
lang = "zh-cn" if lang == "zh" else lang
|
||||
txt = f"[{lang}]{txt}"
|
||||
txt = txt.replace(" ", "[SPACE]")
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
||||
def get_lang(self, line):
|
||||
if line.startswith("[") and line[3:4] == ']':
|
||||
lang = line[1:3].lower()
|
||||
if lang in SUPPORT_LANGUAGES:
|
||||
return lang, line[4:]
|
||||
return "en", line
|
||||
|
||||
def __call__(self, string):
|
||||
lines = string.split("\n")
|
||||
lyric_token_idx = [261]
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
lyric_token_idx += [2]
|
||||
continue
|
||||
|
||||
lang, line = self.get_lang(line)
|
||||
|
||||
if lang not in SUPPORT_LANGUAGES:
|
||||
lang = "en"
|
||||
if "zh" in lang:
|
||||
lang = "zh"
|
||||
if "spa" in lang:
|
||||
lang = "es"
|
||||
|
||||
try:
|
||||
line_out = japanese_to_romaji(line)
|
||||
if line_out != line:
|
||||
lang = "ja"
|
||||
line = line_out
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if structure_pattern.match(line):
|
||||
token_idx = self.encode(line, "en")
|
||||
else:
|
||||
token_idx = self.encode(line, lang)
|
||||
lyric_token_idx = lyric_token_idx + token_idx + [2]
|
||||
except Exception as e:
|
||||
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
|
||||
return {"input_ids": lyric_token_idx}
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, **kwargs):
|
||||
return VoiceBpeTokenizer(path, **kwargs)
|
||||
|
||||
def get_vocab(self):
|
||||
return {}
|
||||
|
||||
|
||||
class UMT5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
|
||||
|
||||
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class LyricsTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
|
||||
|
||||
class AceT5Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
|
||||
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.umt5base.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.umt5base.state_dict()
|
||||
|
||||
class AceT5Model(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__()
|
||||
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.umt5base.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.umt5base.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
|
||||
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
|
||||
|
||||
t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
|
||||
|
||||
lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
|
||||
return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.umt5base.load_sd(sd)
|
||||
15535
comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
395
comfy/text_encoders/ace_text_cleaners.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# basic text cleaners for the ACE step model
|
||||
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
|
||||
# TODO: more languages than english?
|
||||
|
||||
import re
|
||||
|
||||
def japanese_to_romaji(japanese_text):
|
||||
"""
|
||||
Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
|
||||
|
||||
Args:
|
||||
japanese_text (str): Text containing hiragana and/or katakana characters
|
||||
|
||||
Returns:
|
||||
str: The romaji (Latin alphabet) equivalent
|
||||
"""
|
||||
# Dictionary mapping kana characters to their romaji equivalents
|
||||
kana_map = {
|
||||
# Katakana characters
|
||||
'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
|
||||
'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
|
||||
'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
|
||||
'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
|
||||
'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
|
||||
'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
|
||||
'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
|
||||
'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
|
||||
'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
|
||||
'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
|
||||
|
||||
# Katakana voiced consonants
|
||||
'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
|
||||
'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
|
||||
'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
|
||||
'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
|
||||
'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
|
||||
|
||||
# Katakana combinations
|
||||
'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
|
||||
'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
|
||||
'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
|
||||
'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
|
||||
'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
|
||||
'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
|
||||
'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
|
||||
'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
|
||||
'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
|
||||
'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
|
||||
'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
|
||||
|
||||
# Katakana small characters and special cases
|
||||
'ッ': '', # Small tsu (doubles the following consonant)
|
||||
'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
|
||||
|
||||
# Katakana extras
|
||||
'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
|
||||
'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
|
||||
|
||||
# Hiragana characters
|
||||
'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
|
||||
'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
|
||||
'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
|
||||
'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
|
||||
'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
|
||||
'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
|
||||
'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
|
||||
'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
|
||||
'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
|
||||
'わ': 'wa', 'を': 'wo', 'ん': 'n',
|
||||
|
||||
# Hiragana voiced consonants
|
||||
'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
|
||||
'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
|
||||
'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
|
||||
'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
|
||||
'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
|
||||
|
||||
# Hiragana combinations
|
||||
'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
|
||||
'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
|
||||
'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
|
||||
'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
|
||||
'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
|
||||
'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
|
||||
'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
|
||||
'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
|
||||
'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
|
||||
'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
|
||||
'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
|
||||
|
||||
# Hiragana small characters and special cases
|
||||
'っ': '', # Small tsu (doubles the following consonant)
|
||||
'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
|
||||
|
||||
# Common punctuation and spaces
|
||||
' ': ' ', # Japanese space
|
||||
'、': ', ', '。': '. ',
|
||||
}
|
||||
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(japanese_text):
|
||||
# Check for small tsu (doubling the following consonant)
|
||||
if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
|
||||
if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
|
||||
next_romaji = kana_map[japanese_text[i+1]]
|
||||
if next_romaji and next_romaji[0] not in 'aiueon':
|
||||
result.append(next_romaji[0]) # Double the consonant
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check for combinations with small ya, yu, yo
|
||||
if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
|
||||
combo = japanese_text[i:i+2]
|
||||
if combo in kana_map:
|
||||
result.append(kana_map[combo])
|
||||
i += 2
|
||||
continue
|
||||
|
||||
# Regular character
|
||||
if japanese_text[i] in kana_map:
|
||||
result.append(kana_map[japanese_text[i]])
|
||||
else:
|
||||
# If it's not in our map, keep it as is (might be kanji, romaji, etc.)
|
||||
result.append(japanese_text[i])
|
||||
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
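# Examples (illustrative, limited to the kana covered by the table above):
#   japanese_to_romaji("こんにちは") -> "konnichiha"
#   japanese_to_romaji("キャット")   -> "kyatto"   (small tsu doubles the "t")
# Kanji and any other unmapped characters pass through unchanged.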
|
||||
|
||||
def number_to_text(num, ordinal=False):
|
||||
"""
|
||||
Convert a number (int or float) to its text representation.
|
||||
|
||||
Args:
|
||||
num: The number to convert
|
||||
|
||||
Returns:
|
||||
str: Text representation of the number
|
||||
"""
|
||||
|
||||
if not isinstance(num, (int, float)):
|
||||
return "Input must be a number"
|
||||
|
||||
# Handle special case of zero
|
||||
if num == 0:
|
||||
return "zero"
|
||||
|
||||
# Handle negative numbers
|
||||
negative = num < 0
|
||||
num = abs(num)
|
||||
|
||||
# Handle floats
|
||||
if isinstance(num, float):
|
||||
# Split into integer and decimal parts
|
||||
int_part = int(num)
|
||||
|
||||
# Convert both parts
|
||||
int_text = _int_to_text(int_part)
|
||||
|
||||
# Handle decimal part (convert to string and remove '0.')
|
||||
decimal_str = str(num).split('.')[1]
|
||||
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
|
||||
|
||||
result = int_text + decimal_text
|
||||
else:
|
||||
# Handle integers
|
||||
result = _int_to_text(num)
|
||||
|
||||
# Add 'negative' prefix for negative numbers
|
||||
if negative:
|
||||
result = "negative " + result
|
||||
|
||||
return result
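# Examples (illustrative); note the ordinal flag is accepted but not yet used:
#   number_to_text(0)    -> "zero"
#   number_to_text(21)   -> "twenty one"
#   number_to_text(-3.5) -> "negative three point five"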
|
||||
|
||||
|
||||
def _int_to_text(num):
|
||||
"""Helper function to convert an integer to text"""
|
||||
|
||||
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
|
||||
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
|
||||
"seventeen", "eighteen", "nineteen"]
|
||||
|
||||
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
|
||||
|
||||
if num < 20:
|
||||
return ones[num]
|
||||
|
||||
if num < 100:
|
||||
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
|
||||
|
||||
if num < 1000:
|
||||
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
|
||||
|
||||
if num < 1000000:
|
||||
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
|
||||
|
||||
if num < 1000000000:
|
||||
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
|
||||
|
||||
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
|
||||
|
||||
|
||||
def _digit_to_text(digit):
|
||||
"""Convert a single digit to text"""
|
||||
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
||||
return digits[digit]
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = {
|
||||
"en": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_abbreviations_multilingual(text, lang="en"):
|
||||
for regex, replacement in _abbreviations[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
_symbols_multilingual = {
|
||||
"en": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " and "),
|
||||
("@", " at "),
|
||||
("%", " percent "),
|
||||
("#", " hash "),
|
||||
("$", " dollar "),
|
||||
("£", " pound "),
|
||||
("°", " degree "),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_symbols_multilingual(text, lang="en"):
|
||||
for regex, replacement in _symbols_multilingual[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
text = text.replace(" ", " ") # Ensure there are no double spaces
|
||||
return text.strip()
|
||||
|
||||
|
||||
_ordinal_re = {
|
||||
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
||||
}
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
_currency_re = {
|
||||
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
||||
}
|
||||
|
||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
text = m.group(0)
|
||||
if "," in text:
|
||||
text = text.replace(",", "")
|
||||
return text
|
||||
|
||||
|
||||
def _remove_dots(m):
|
||||
text = m.group(0)
|
||||
if "." in text:
|
||||
text = text.replace(".", "")
|
||||
return text
|
||||
|
||||
|
||||
def _expand_decimal_point(m, lang="en"):
|
||||
amount = m.group(1).replace(",", ".")
|
||||
return number_to_text(float(amount))
|
||||
|
||||
|
||||
def _expand_currency(m, lang="en", currency="USD"):
|
||||
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
||||
full_amount = number_to_text(amount)
|
||||
|
||||
and_equivalents = {
|
||||
"en": ", ",
|
||||
"es": " con ",
|
||||
"fr": " et ",
|
||||
"de": " und ",
|
||||
"pt": " e ",
|
||||
"it": " e ",
|
||||
"pl": ", ",
|
||||
"cs": ", ",
|
||||
"ru": ", ",
|
||||
"nl": ", ",
|
||||
"ar": ", ",
|
||||
"tr": ", ",
|
||||
"hu": ", ",
|
||||
"ko": ", ",
|
||||
}
|
||||
|
||||
if amount.is_integer():
|
||||
last_and = full_amount.rfind(and_equivalents[lang])
|
||||
if last_and != -1:
|
||||
full_amount = full_amount[:last_and]
|
||||
|
||||
return full_amount
|
||||
|
||||
|
||||
def _expand_ordinal(m, lang="en"):
|
||||
return number_to_text(int(m.group(1)), ordinal=True)
|
||||
|
||||
|
||||
def _expand_number(m, lang="en"):
|
||||
return number_to_text(int(m.group(0)))
|
||||
|
||||
|
||||
def expand_numbers_multilingual(text, lang="en"):
|
||||
if lang in ["en", "ru"]:
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
else:
|
||||
text = re.sub(_dot_number_re, _remove_dots, text)
|
||||
try:
|
||||
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
||||
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
||||
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
||||
except:
|
||||
pass
|
||||
|
||||
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
|
||||
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
|
||||
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
||||
return text
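# Examples (illustrative):
#   expand_numbers_multilingual("track 3 of 12", "en") -> "track three of twelve"
#   expand_numbers_multilingual("1st place", "en")     -> "one place"
# (ordinals are currently read as plain cardinals, see _expand_ordinal above)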
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def multilingual_cleaners(text, lang):
|
||||
text = text.replace('"', "")
|
||||
if lang == "tr":
|
||||
text = text.replace("İ", "i")
|
||||
text = text.replace("Ö", "ö")
|
||||
text = text.replace("Ü", "ü")
|
||||
text = lowercase(text)
|
||||
try:
|
||||
text = expand_numbers_multilingual(text, lang)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
text = expand_abbreviations_multilingual(text, lang)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
text = expand_symbols_multilingual(text, lang=lang)
|
||||
except:
|
||||
pass
|
||||
text = collapse_whitespace(text)
|
||||
return text
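# Example (illustrative):
#   multilingual_cleaners('Dr. Smith owns 2 cats', 'en')
#   -> "doctor smith owns two cats"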
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
22
comfy/text_encoders/umt5_config_base.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 2048,
|
||||
"d_kv": 64,
|
||||
"d_model": 768,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "umt5",
|
||||
"num_decoder_layers": 12,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 256384
|
||||
}
|
||||
@@ -28,6 +28,9 @@ import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
@@ -67,12 +70,14 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||
raise e
|
||||
else:
|
||||
torch_args = {}
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
|
||||
@@ -43,3 +43,13 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
|
||||
41
comfy_api_nodes/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# ComfyUI API Nodes
|
||||
|
||||
## Introduction
|
||||
|
||||
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
|
||||
|
||||
## Development
|
||||
|
||||
While developing, you should be testing against the Staging environment. To test against staging:
|
||||
|
||||
**Install ComfyUI_frontend**
|
||||
|
||||
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
|
||||
|
||||
> **Hint:** If you use the `--front-end-version` argument for ComfyUI, it will use production authentication.
|
||||
|
||||
```bash
|
||||
python main.py --comfy-api-base https://stagingapi.comfy.org
|
||||
```
|
||||
|
||||
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
||||
|
||||
### Redocly Instructions
|
||||
|
||||
**Tip**
|
||||
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
|
||||
|
||||
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
||||
|
||||
```bash
|
||||
# Download the OpenAPI file from the staging server.
|
||||
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
||||
|
||||
# Filter out unneeded API definitions.
|
||||
npm install -g @redocly/cli
|
||||
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
|
||||
|
||||
# Generate the pydantic datamodels for validation.
|
||||
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
```
|
||||
583
comfy_api_nodes/apinode_utils.py
Normal file
@@ -0,0 +1,583 @@
|
||||
from __future__ import annotations
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiClient,
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
video_url: The URL of the video to download.
|
||||
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = download_url_to_bytesio(video_url, timeout)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return VideoFromFile(video_io)
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
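# Example (illustrative): a square 2048x2048 input is reduced to roughly the
# 1536*1024 pixel budget while keeping its aspect ratio:
#   img = torch.rand(1, 2048, 2048, 3)   # [B, H, W, C]
#   downscale_image_tensor(img).shape    # -> torch.Size([1, 1254, 1254, 3])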
|
||||
|
||||
|
||||
def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
for image_data in data:
|
||||
image_url = image_data.url
|
||||
b64_data = image_data.b64_json
|
||||
|
||||
if not image_url and not b64_data:
|
||||
raise ValueError("No image was generated in the response")
|
||||
|
||||
if b64_data:
|
||||
img_data = base64.b64decode(b64_data)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
elif image_url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {image_url}", node_id
|
||||
)
|
||||
img_response = requests.get(image_url, timeout=timeout)
|
||||
if img_response.status_code != 200:
|
||||
raise ValueError("Failed to download the image")
|
||||
img = Image.open(io.BytesIO(img_response.content))
|
||||
|
||||
img = img.convert("RGBA")
|
||||
|
||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)
|
||||
|
||||
# Add to list of tensors
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def validate_aspect_ratio(
|
||||
aspect_ratio: str,
|
||||
minimum_ratio: float,
|
||||
maximum_ratio: float,
|
||||
minimum_ratio_str: str,
|
||||
maximum_ratio_str: str,
|
||||
) -> float:
|
||||
"""Validates and casts an aspect ratio string to a float.
|
||||
|
||||
Args:
|
||||
aspect_ratio: The aspect ratio string to validate.
|
||||
minimum_ratio: The minimum aspect ratio.
|
||||
maximum_ratio: The maximum aspect ratio.
|
||||
minimum_ratio_str: The minimum aspect ratio string.
|
||||
maximum_ratio_str: The maximum aspect ratio string.
|
||||
|
||||
Returns:
|
||||
The validated and cast aspect ratio.
|
||||
|
||||
Raises:
|
||||
Exception: If the aspect ratio is not valid.
|
||||
"""
|
||||
# get ratio values
|
||||
numbers = aspect_ratio.split(":")
|
||||
if len(numbers) != 2:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||
)
|
||||
try:
|
||||
numerator = int(numbers[0])
|
||||
denominator = int(numbers[1])
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||
) from exc
|
||||
calculated_ratio = numerator / denominator
|
||||
# if not close to minimum and maximum, check bounds
|
||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||
calculated_ratio, maximum_ratio
|
||||
):
|
||||
if calculated_ratio < minimum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
elif calculated_ratio > maximum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
return aspect_ratio
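# Illustrative usage (the bounds below are made-up example values):
#   validate_aspect_ratio("16:9", minimum_ratio=0.5, maximum_ratio=2.0,
#                         minimum_ratio_str="1:2", maximum_ratio_str="2:1")
#   -> "16:9"
# Note the function returns the original string (despite the float annotation)
# and raises TypeError when the ratio falls outside the allowed bounds.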
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
url: The URL to download.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
response = requests.get(url, stream=True, timeout=timeout)
|
||||
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(response.content)
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
    """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
    if len(image.shape) > 3:
        image = image[0]
    # TODO: remove alpha if not allowed and present
    input_tensor = image.cpu()
    input_tensor = downscale_image_tensor(
        input_tensor.unsqueeze(0), total_pixels=total_pixels
    ).squeeze()
    image_np = (input_tensor.numpy() * 255).astype(np.uint8)
    img = Image.fromarray(image_np)
    return img


def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
    """Converts a PIL Image to a BytesIO object."""
    if not mime_type:
        mime_type = "image/png"

    img_byte_arr = io.BytesIO()
    # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
    pil_format = mime_type.split("/")[-1].upper()
    if pil_format == "JPG":
        pil_format = "JPEG"
    img.save(img_byte_arr, format=pil_format)
    img_byte_arr.seek(0)
    return img_byte_arr


def tensor_to_bytesio(
    image: torch.Tensor,
    name: Optional[str] = None,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> BytesIO:
    """Converts a torch.Tensor image to a named BytesIO object.

    Args:
        image: Input torch.Tensor image.
        name: Optional filename for the BytesIO object.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').

    Returns:
        Named BytesIO object containing the image data.
    """
    if not mime_type:
        mime_type = "image/png"

    pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
    img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
    img_binary.name = (
        f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
    )
    return img_binary


def tensor_to_base64_string(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.

    Args:
        image_tensor: Input torch.Tensor image.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').

    Returns:
        Base64 encoded string of the image.
    """
    pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
    img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
    img_bytes = img_byte_arr.getvalue()
    # Encode bytes to base64 string
    base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
    return base64_encoded_string


def tensor_to_data_uri(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Converts a tensor image to a Data URI string.

    Args:
        image_tensor: Input torch.Tensor image.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

    Returns:
        Data URI string (e.g., 'data:image/png;base64,...').
    """
    base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
    return f"data:{mime_type};base64,{base64_string}"
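
A quick sketch of the tensor-to-string conversions above; the dummy tensor stands in for a Comfy IMAGE output:

import torch

dummy = torch.rand(64, 64, 3)  # [H, W, C] float image in [0, 1]

b64 = tensor_to_base64_string(dummy, mime_type="image/jpeg")
uri = tensor_to_data_uri(dummy, mime_type="image/png")
print(uri[:30])  # starts with "data:image/png;base64,"
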
def upload_file_to_comfyapi(
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: str,
    auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
    """
    Uploads a single file to ComfyUI API and returns its download URL.

    Args:
        file_bytes_io: BytesIO object containing the file data.
        filename: The filename of the file.
        upload_mime_type: MIME type of the file.
        auth_kwargs: Optional authentication token(s).

    Returns:
        The download URL for the uploaded file.
    """
    request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/customers/storage",
            method=HttpMethod.POST,
            request_model=UploadRequest,
            response_model=UploadResponse,
        ),
        request=request_object,
        auth_kwargs=auth_kwargs,
    )

    response: UploadResponse = operation.execute()
    upload_response = ApiClient.upload_file(
        response.upload_url, file_bytes_io, content_type=upload_mime_type
    )
    upload_response.raise_for_status()

    return response.download_url


def upload_video_to_comfyapi(
    video: VideoInput,
    auth_kwargs: Optional[dict[str,str]] = None,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
) -> str:
    """
    Uploads a single video to ComfyUI API and returns its download URL.
    Uses the specified container and codec for saving the video before upload.

    Args:
        video: VideoInput object (Comfy VIDEO type).
        auth_kwargs: Optional authentication token(s).
        container: The video container format to use (default: MP4).
        codec: The video codec to use (default: H264).
        max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.

    Returns:
        The download URL for the uploaded video file.
    """
    if max_duration is not None:
        try:
            actual_duration = video.duration_seconds
            if actual_duration is not None and actual_duration > max_duration:
                raise ValueError(
                    f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
                )
        except Exception as e:
            logging.error(f"Error getting video duration: {e}")
            raise ValueError(f"Could not verify video duration from source: {e}") from e

    upload_mime_type = f"video/{container.value.lower()}"
    filename = f"uploaded_video.{container.value.lower()}"

    # Convert VideoInput to BytesIO using specified container/codec
    video_bytes_io = io.BytesIO()
    video.save_to(video_bytes_io, format=container, codec=codec)
    video_bytes_io.seek(0)

    return upload_file_to_comfyapi(
        video_bytes_io, filename, upload_mime_type, auth_kwargs
    )


def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
    """
    Prepares audio waveform for av library by converting to a contiguous numpy array.

    Args:
        waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.

    Returns:
        Contiguous numpy array of the audio waveform. If the audio was batched,
        the first item is taken.
    """
    if waveform.ndim != 3 or waveform.shape[0] != 1:
        raise ValueError("Expected waveform tensor shape (1, channels, samples)")

    # If batch is > 1, take first item
    if waveform.shape[0] > 1:
        waveform = waveform[0]

    # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
    audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
    if audio_data_np.dtype != np.float32:
        audio_data_np = audio_data_np.astype(np.float32)

    return audio_data_np


def audio_ndarray_to_bytesio(
    audio_data_np: np.ndarray,
    sample_rate: int,
    container_format: str = "mp4",
    codec_name: str = "aac",
) -> BytesIO:
    """
    Encodes a numpy array of audio data into a BytesIO object.
    """
    audio_bytes_io = io.BytesIO()
    with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
        audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
        frame = av.AudioFrame.from_ndarray(
            audio_data_np,
            format="fltp",
            layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
        )
        frame.sample_rate = sample_rate
        frame.pts = 0

        for packet in audio_stream.encode(frame):
            output_container.mux(packet)

        # Flush stream
        for packet in audio_stream.encode(None):
            output_container.mux(packet)

    audio_bytes_io.seek(0)
    return audio_bytes_io
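
A small sketch of the audio encoding path, assuming PyAV is installed (the canary module later in this diff requires av >= 14.2):

import numpy as np

# One second of a 440 Hz sine at 44.1 kHz, shaped (channels, samples) as float32,
# which matches the planar ('fltp') layout audio_ndarray_to_bytesio builds frames from.
sr = 44100
t = np.linspace(0, 1, sr, endpoint=False, dtype=np.float32)
mono = np.sin(2 * np.pi * 440 * t).astype(np.float32)[None, :]  # shape (1, samples)

aac_bytes = audio_ndarray_to_bytesio(mono, sr, container_format="mp4", codec_name="aac")
print(len(aac_bytes.getvalue()), "bytes of AAC-in-MP4")
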
def upload_audio_to_comfyapi(
    audio: AudioInput,
    auth_kwargs: Optional[dict[str,str]] = None,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
    filename: str = "uploaded_audio.mp4",
) -> str:
    """
    Uploads a single audio input to ComfyUI API and returns its download URL.
    Encodes the raw waveform into the specified format before uploading.

    Args:
        audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
        auth_kwargs: Optional authentication token(s).

    Returns:
        The download URL for the uploaded audio file.
    """
    sample_rate: int = audio["sample_rate"]
    waveform: torch.Tensor = audio["waveform"]
    audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
    audio_bytes_io = audio_ndarray_to_bytesio(
        audio_data_np, sample_rate, container_format, codec_name
    )

    return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


def upload_images_to_comfyapi(
    image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
    To upload multiple images, stack them in the batch dimension first.

    Args:
        image: Input torch.Tensor image.
        max_images: Maximum number of images to upload.
        auth_kwargs: Optional authentication token(s).
        mime_type: Optional MIME type for the image.
    """
    # if batch, try to upload each file if max_images is greater than 0
    idx_image = 0
    download_urls: list[str] = []
    is_batch = len(image.shape) > 3
    batch_length = 1
    if is_batch:
        batch_length = image.shape[0]
    while True:
        curr_image = image
        if len(image.shape) > 3:
            curr_image = image[idx_image]
        # get BytesIO version of image
        img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
        # first, request upload/download urls from comfy API
        if not mime_type:
            request_object = UploadRequest(file_name=img_binary.name)
        else:
            request_object = UploadRequest(
                file_name=img_binary.name, content_type=mime_type
            )
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/customers/storage",
                method=HttpMethod.POST,
                request_model=UploadRequest,
                response_model=UploadResponse,
            ),
            request=request_object,
            auth_kwargs=auth_kwargs,
        )
        response = operation.execute()

        upload_response = ApiClient.upload_file(
            response.upload_url, img_binary, content_type=mime_type
        )
        # verify success
        try:
            upload_response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ValueError(f"Could not upload one or more images: {e}") from e
        # add download_url to list
        download_urls.append(response.download_url)

        idx_image += 1
        # stop uploading additional files if done
        if is_batch and max_images > 0:
            if idx_image >= max_images:
                break
        if idx_image >= batch_length:
            break
    return download_urls
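
A hedged sketch of a batched upload; the credential keys below are placeholders, since the exact names expected by ApiClient are not shown in this diff:

import torch

auth = {"auth_token": "<token>"}    # placeholder auth_kwargs
batch = torch.rand(3, 512, 512, 3)  # three images stacked in the batch dimension

urls = upload_images_to_comfyapi(batch, max_images=3, auth_kwargs=auth, mime_type="image/png")
print(urls)  # one download URL per uploaded image
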
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
                         upscale_method="nearest-exact", crop="disabled",
                         allow_gradient=True, add_channel_dim=False):
    """
    Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
    """
    _, H, W, _ = image.shape
    mask = mask.unsqueeze(-1)
    mask = mask.movedim(-1, 1)
    mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
    mask = mask.movedim(1, -1)
    if not add_channel_dim:
        mask = mask.squeeze(-1)
    if not allow_gradient:
        mask = (mask > 0.5).float()
    return mask
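
A shape-level sketch of the mask helper above (requires the ComfyUI runtime for common_upscale):

import torch

mask = torch.rand(1, 64, 64)        # [B, H, W] mask in [0, 1]
image = torch.rand(1, 512, 512, 3)  # [B, H, W, C] image

resized = resize_mask_to_image(mask, image, allow_gradient=False)
print(resized.shape)  # torch.Size([1, 512, 512]); values are 0/1 because allow_gradient=False
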
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
    if strip_whitespace:
        string = string.strip()
    if min_length and len(string) < min_length:
        raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
    if max_length and len(string) > max_length:
        raise Exception(f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long.")
    if not string:
        raise Exception(f"Field '{field_name}' cannot be empty.")
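
Behaviour of the validator above, in brief:

validate_string("  a scenic mountain lake  ", min_length=3, max_length=2048)  # passes after stripping

try:
    validate_string("hi", field_name="prompt", min_length=3)
except Exception as e:
    print(e)  # Field 'prompt' cannot be shorter than 3 characters; was 2 characters long.
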
@@ -1,6 +1,6 @@
 # generated by datamodel-codegen:
-#   filename:  https://api.comfy.org/openapi
-#   timestamp: 2025-04-23T15:56:33+00:00
+#   filename:  filtered-openapi.yaml
+#   timestamp: 2025-04-29T23:44:54+00:00

 from __future__ import annotations

@@ -1,12 +1,12 @@
 # generated by datamodel-codegen:
-#   filename:  https://api.comfy.org/openapi
-#   timestamp: 2025-04-23T15:56:33+00:00
+#   filename:  filtered-openapi.yaml
+#   timestamp: 2025-04-29T23:44:54+00:00

 from __future__ import annotations

 from typing import Optional

-from pydantic import BaseModel, Field, constr
+from pydantic import BaseModel, Field


 class V2OpenAPII2VResp(BaseModel):
@@ -30,10 +30,10 @@ class V2OpenAPIT2VReq(BaseModel):
         description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
         examples=['normal'],
     )
-    negative_prompt: Optional[constr(max_length=2048)] = Field(
-        None, description='Negative prompt\n'
+    negative_prompt: Optional[str] = Field(
+        None, description='Negative prompt\n', max_length=2048
     )
-    prompt: constr(max_length=2048) = Field(..., description='Prompt')
+    prompt: str = Field(..., description='Prompt', max_length=2048)
     quality: str = Field(
         ...,
         description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
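
The hunks above swap pydantic's constr(...) annotations for plain str fields with max_length passed to Field; a minimal equivalence sketch, assuming pydantic v2:

from pydantic import BaseModel, Field, ValidationError

class Req(BaseModel):
    prompt: str = Field(..., description="Prompt", max_length=8)

Req(prompt="short")  # accepted
try:
    Req(prompt="definitely too long")
except ValidationError as err:
    print(err.errors()[0]["type"])  # "string_too_long"
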
File diff suppressed because it is too large
156
comfy_api_nodes/apis/bfl_api.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat, conint
|
||||
|
||||
|
||||
class BFLOutputFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
|
||||
|
||||
class BFLFluxExpandImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||
|
||||
|
||||
class BFLFluxFillImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
|
||||
|
||||
class BFLFluxCannyImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxDepthImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
# None, description='Blend between the prompt and the image prompt.'
|
||||
# )
|
||||
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
None, description='Blend between the prompt and the image prompt.'
|
||||
)
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description='The unique identifier for the generation task.')
|
||||
polling_url: str = Field(..., description='URL to poll for the generation result.')
|
||||
|
||||
|
||||
class BFLStatus(str, Enum):
|
||||
task_not_found = "Task not found"
|
||||
pending = "Pending"
|
||||
request_moderated = "Request Moderated"
|
||||
content_moderated = "Content Moderated"
|
||||
ready = "Ready"
|
||||
error = "Error"
|
||||
|
||||
|
||||
class BFLFluxProStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
None, description="The result of the task (null if not completed)."
|
||||
)
|
||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||
..., description="The progress of the task (0.0 to 1.0)."
|
||||
)
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional details about the task (null if not available)."
|
||||
)
|
||||
File diff suppressed because it is too large
253
comfy_api_nodes/apis/luma_api.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
|
||||
|
||||
class LumaReference:
|
||||
def __init__(self, image: torch.Tensor, weight: float):
|
||||
self.image = image
|
||||
self.weight = weight
|
||||
|
||||
def create_api_model(self, download_url: str):
|
||||
return LumaImageRef(url=download_url, weight=self.weight)
|
||||
|
||||
class LumaReferenceChain:
|
||||
def __init__(self, first_ref: LumaReference=None):
|
||||
self.refs: list[LumaReference] = []
|
||||
if first_ref:
|
||||
self.refs.append(first_ref)
|
||||
|
||||
def add(self, luma_ref: LumaReference=None):
|
||||
self.refs.append(luma_ref)
|
||||
|
||||
def create_api_model(self, download_urls: list[str], max_refs=4):
|
||||
if len(self.refs) == 0:
|
||||
return None
|
||||
api_refs: list[LumaImageRef] = []
|
||||
for ref, url in zip(self.refs, download_urls):
|
||||
api_ref = LumaImageRef(url=url, weight=ref.weight)
|
||||
api_refs.append(api_ref)
|
||||
return api_refs
|
||||
|
||||
def clone(self):
|
||||
c = LumaReferenceChain()
|
||||
for ref in self.refs:
|
||||
c.add(ref)
|
||||
return c
|
||||
|
||||
|
||||
class LumaConcept:
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
|
||||
class LumaConceptChain:
|
||||
def __init__(self, str_list: list[str] = None):
|
||||
self.concepts: list[LumaConcept] = []
|
||||
if str_list is not None:
|
||||
for c in str_list:
|
||||
if c != "None":
|
||||
self.add(LumaConcept(key=c))
|
||||
|
||||
def add(self, concept: LumaConcept):
|
||||
self.concepts.append(concept)
|
||||
|
||||
def create_api_model(self):
|
||||
if len(self.concepts) == 0:
|
||||
return None
|
||||
api_concepts: list[LumaConceptObject] = []
|
||||
for concept in self.concepts:
|
||||
if concept.key == "None":
|
||||
continue
|
||||
api_concepts.append(LumaConceptObject(key=concept.key))
|
||||
if len(api_concepts) == 0:
|
||||
return None
|
||||
return api_concepts
|
||||
|
||||
def clone(self):
|
||||
c = LumaConceptChain()
|
||||
for concept in self.concepts:
|
||||
c.add(concept)
|
||||
return c
|
||||
|
||||
def clone_and_merge(self, other: LumaConceptChain):
|
||||
c = self.clone()
|
||||
for concept in other.concepts:
|
||||
c.add(concept)
|
||||
return c
|
||||
|
||||
|
||||
def get_luma_concepts(include_none=False):
|
||||
concepts = []
|
||||
if include_none:
|
||||
concepts.append("None")
|
||||
return concepts + [
|
||||
"truck_left",
|
||||
"pan_right",
|
||||
"pedestal_down",
|
||||
"low_angle",
|
||||
"pedestal_up",
|
||||
"selfie",
|
||||
"pan_left",
|
||||
"roll_right",
|
||||
"zoom_in",
|
||||
"over_the_shoulder",
|
||||
"orbit_right",
|
||||
"orbit_left",
|
||||
"static",
|
||||
"tiny_planet",
|
||||
"high_angle",
|
||||
"bolt_cam",
|
||||
"dolly_zoom",
|
||||
"overhead",
|
||||
"zoom_out",
|
||||
"handheld",
|
||||
"roll_left",
|
||||
"pov",
|
||||
"aerial_drone",
|
||||
"push_in",
|
||||
"crane_down",
|
||||
"truck_right",
|
||||
"tilt_down",
|
||||
"elevator_doors",
|
||||
"tilt_up",
|
||||
"ground_level",
|
||||
"pull_out",
|
||||
"aerial",
|
||||
"crane_up",
|
||||
"eye_level"
|
||||
]
|
||||
|
||||
|
||||
class LumaImageModel(str, Enum):
|
||||
photon_1 = "photon-1"
|
||||
photon_flash_1 = "photon-flash-1"
|
||||
|
||||
|
||||
class LumaVideoModel(str, Enum):
|
||||
ray_2 = "ray-2"
|
||||
ray_flash_2 = "ray-flash-2"
|
||||
ray_1_6 = "ray-1-6"
|
||||
|
||||
|
||||
class LumaAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_4_3 = "4:3"
|
||||
ratio_3_4 = "3:4"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
class LumaVideoOutputResolution(str, Enum):
|
||||
res_540p = "540p"
|
||||
res_720p = "720p"
|
||||
res_1080p = "1080p"
|
||||
res_4k = "4k"
|
||||
|
||||
|
||||
class LumaVideoModelOutputDuration(str, Enum):
|
||||
dur_5s = "5s"
|
||||
dur_9s = "9s"
|
||||
|
||||
|
||||
class LumaGenerationType(str, Enum):
|
||||
video = 'video'
|
||||
image = 'image'
|
||||
|
||||
|
||||
class LumaState(str, Enum):
|
||||
queued = "queued"
|
||||
dreaming = "dreaming"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class LumaAssets(BaseModel):
|
||||
video: Optional[str] = Field(None, description='The URL of the video')
|
||||
image: Optional[str] = Field(None, description='The URL of the image')
|
||||
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
'''Used for image gen'''
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
'''Used for video gen'''
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
|
||||
class LumaModifyImageRef(BaseModel):
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaCharacterRef(BaseModel):
|
||||
identity0: LumaImageIdentity = Field(..., description='The image identity object')
|
||||
|
||||
|
||||
class LumaImageIdentity(BaseModel):
|
||||
images: list[str] = Field(..., description='The URLs of the image identity')
|
||||
|
||||
|
||||
class LumaGenerationReference(BaseModel):
|
||||
type: str = Field('generation', description='Input type, defaults to generation')
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
|
||||
|
||||
class LumaKeyframes(BaseModel):
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
|
||||
|
||||
class LumaConceptObject(BaseModel):
|
||||
key: str = Field(..., description='Camera Concept name')
|
||||
|
||||
|
||||
class LumaImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
|
||||
|
||||
|
||||
class LumaGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
|
||||
loop: Optional[bool] = Field(None, description='Whether to loop the video')
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
|
||||
|
||||
|
||||
class LumaGeneration(BaseModel):
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
|
||||
state: LumaState = Field(..., description='The state of the generation')
|
||||
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
|
||||
created_at: str = Field(..., description='The date and time when the generation was created')
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||
146
comfy_api_nodes/apis/pixverse_api.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
pixverse_templates = {
|
||||
"Microwave": 324641385496960,
|
||||
"Suit Swagger": 328545151283968,
|
||||
"Anything, Robot": 313358700761536,
|
||||
"Subject 3 Fever": 327828816843648,
|
||||
"kiss kiss": 315446315336768,
|
||||
}
|
||||
|
||||
|
||||
class PixverseIO:
|
||||
TEMPLATE = "PIXVERSE_TEMPLATE"
|
||||
|
||||
|
||||
class PixverseStatus(int, Enum):
|
||||
successful = 1
|
||||
generating = 5
|
||||
deleted = 6
|
||||
contents_moderation = 7
|
||||
failed = 8
|
||||
|
||||
|
||||
class PixverseAspectRatio(str, Enum):
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_4_3 = "4:3"
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_3_4 = "3:4"
|
||||
ratio_9_16 = "9:16"
|
||||
|
||||
|
||||
class PixverseQuality(str, Enum):
|
||||
res_360p = "360p"
|
||||
res_540p = "540p"
|
||||
res_720p = "720p"
|
||||
res_1080p = "1080p"
|
||||
|
||||
|
||||
class PixverseDuration(int, Enum):
|
||||
dur_5 = 5
|
||||
dur_8 = 8
|
||||
|
||||
|
||||
class PixverseMotionMode(str, Enum):
|
||||
normal = "normal"
|
||||
fast = "fast"
|
||||
|
||||
|
||||
class PixverseStyle(str, Enum):
|
||||
anime = "anime"
|
||||
animation_3d = "3d_animation"
|
||||
clay = "clay"
|
||||
comic = "comic"
|
||||
cyberpunk = "cyberpunk"
|
||||
|
||||
|
||||
# NOTE: forgoing descriptions for now in return for dev speed
|
||||
class PixverseTextVideoRequest(BaseModel):
|
||||
aspect_ratio: PixverseAspectRatio = Field(...)
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
template_id: Optional[int] = Field(None)
|
||||
water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseImageVideoRequest(BaseModel):
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
img_id: int = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
template_id: Optional[int] = Field(None)
|
||||
water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseTransitionVideoRequest(BaseModel):
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
first_frame_img: int = Field(...)
|
||||
last_frame_img: int = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
# negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
# style: Optional[str] = Field(None)
|
||||
# template_id: Optional[int] = Field(None)
|
||||
# water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseImageUploadResponse(BaseModel):
|
||||
ErrCode: Optional[int] = None
|
||||
ErrMsg: Optional[str] = None
|
||||
Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp')
|
||||
|
||||
|
||||
class PixverseImgIdResponseObject(BaseModel):
|
||||
img_id: Optional[int] = None
|
||||
|
||||
|
||||
class PixverseVideoResponse(BaseModel):
|
||||
ErrCode: Optional[int] = Field(None)
|
||||
ErrMsg: Optional[str] = Field(None)
|
||||
Resp: Optional[PixverseVideoIdResponseObject] = Field(None)
|
||||
|
||||
|
||||
class PixverseVideoIdResponseObject(BaseModel):
|
||||
video_id: int = Field(..., description='Video_id')
|
||||
|
||||
|
||||
class PixverseGenerationStatusResponse(BaseModel):
|
||||
ErrCode: Optional[int] = Field(None)
|
||||
ErrMsg: Optional[str] = Field(None)
|
||||
Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None)
|
||||
|
||||
|
||||
class PixverseGenerationStatusResponseObject(BaseModel):
|
||||
create_time: Optional[str] = Field(None)
|
||||
id: Optional[int] = Field(None)
|
||||
modify_time: Optional[str] = Field(None)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
outputHeight: Optional[int] = Field(None)
|
||||
outputWidth: Optional[int] = Field(None)
|
||||
prompt: Optional[str] = Field(None)
|
||||
resolution_ratio: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
size: Optional[int] = Field(None)
|
||||
status: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
url: Optional[str] = Field(None)
|
||||
262
comfy_api_nodes/apis/recraft_api.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, conint, confloat
|
||||
|
||||
|
||||
class RecraftColor:
|
||||
def __init__(self, r: int, g: int, b: int):
|
||||
self.color = [r, g, b]
|
||||
|
||||
def create_api_model(self):
|
||||
return RecraftColorObject(rgb=self.color)
|
||||
|
||||
|
||||
class RecraftColorChain:
|
||||
def __init__(self):
|
||||
self.colors: list[RecraftColor] = []
|
||||
|
||||
def get_first(self):
|
||||
if len(self.colors) > 0:
|
||||
return self.colors[0]
|
||||
return None
|
||||
|
||||
def add(self, color: RecraftColor):
|
||||
self.colors.append(color)
|
||||
|
||||
def create_api_model(self):
|
||||
if not self.colors:
|
||||
return None
|
||||
colors_api = [x.create_api_model() for x in self.colors]
|
||||
return colors_api
|
||||
|
||||
def clone(self):
|
||||
c = RecraftColorChain()
|
||||
for color in self.colors:
|
||||
c.add(color)
|
||||
return c
|
||||
|
||||
def clone_and_merge(self, other: RecraftColorChain):
|
||||
c = self.clone()
|
||||
for color in other.colors:
|
||||
c.add(color)
|
||||
return c
|
||||
|
||||
|
||||
class RecraftControls:
|
||||
def __init__(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None,
|
||||
artistic_level: int=None, no_text: bool=None):
|
||||
self.colors = colors
|
||||
self.background_color = background_color
|
||||
self.artistic_level = artistic_level
|
||||
self.no_text = no_text
|
||||
|
||||
def create_api_model(self):
|
||||
if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None:
|
||||
return None
|
||||
colors_api = None
|
||||
background_color_api = None
|
||||
if self.colors:
|
||||
colors_api = self.colors.create_api_model()
|
||||
if self.background_color:
|
||||
first_background = self.background_color.get_first()
|
||||
background_color_api = first_background.create_api_model() if first_background else None
|
||||
|
||||
return RecraftControlsObject(colors=colors_api, background_color=background_color_api,
|
||||
artistic_level=self.artistic_level, no_text=self.no_text)
|
||||
|
||||
|
||||
class RecraftStyle:
|
||||
def __init__(self, style: str=None, substyle: str=None, style_id: str=None):
|
||||
self.style = style
|
||||
if substyle == "None":
|
||||
substyle = None
|
||||
self.substyle = substyle
|
||||
self.style_id = style_id
|
||||
|
||||
|
||||
class RecraftIO:
|
||||
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||
COLOR = "RECRAFT_COLOR"
|
||||
CONTROLS = "RECRAFT_CONTROLS"
|
||||
|
||||
|
||||
class RecraftStyleV3(str, Enum):
|
||||
#any = 'any' NOTE: this does not work for some reason... why?
|
||||
realistic_image = 'realistic_image'
|
||||
digital_illustration = 'digital_illustration'
|
||||
vector_illustration = 'vector_illustration'
|
||||
logo_raster = 'logo_raster'
|
||||
|
||||
|
||||
def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
|
||||
substyles: list[str] = []
|
||||
if include_none:
|
||||
substyles.append("None")
|
||||
return substyles + dict_recraft_substyles_v3.get(style_v3, [])
|
||||
|
||||
|
||||
dict_recraft_substyles_v3 = {
|
||||
RecraftStyleV3.realistic_image: [
|
||||
"b_and_w",
|
||||
"enterprise",
|
||||
"evening_light",
|
||||
"faded_nostalgia",
|
||||
"forest_life",
|
||||
"hard_flash",
|
||||
"hdr",
|
||||
"motion_blur",
|
||||
"mystic_naturalism",
|
||||
"natural_light",
|
||||
"natural_tones",
|
||||
"organic_calm",
|
||||
"real_life_glow",
|
||||
"retro_realism",
|
||||
"retro_snapshot",
|
||||
"studio_portrait",
|
||||
"urban_drama",
|
||||
"village_realism",
|
||||
"warm_folk"
|
||||
],
|
||||
RecraftStyleV3.digital_illustration: [
|
||||
"2d_art_poster",
|
||||
"2d_art_poster_2",
|
||||
"antiquarian",
|
||||
"bold_fantasy",
|
||||
"child_book",
|
||||
"child_books",
|
||||
"cover",
|
||||
"crosshatch",
|
||||
"digital_engraving",
|
||||
"engraving_color",
|
||||
"expressionism",
|
||||
"freehand_details",
|
||||
"grain",
|
||||
"grain_20",
|
||||
"graphic_intensity",
|
||||
"hand_drawn",
|
||||
"hand_drawn_outline",
|
||||
"handmade_3d",
|
||||
"hard_comics",
|
||||
"infantile_sketch",
|
||||
"long_shadow",
|
||||
"modern_folk",
|
||||
"multicolor",
|
||||
"neon_calm",
|
||||
"noir",
|
||||
"nostalgic_pastel",
|
||||
"outline_details",
|
||||
"pastel_gradient",
|
||||
"pastel_sketch",
|
||||
"pixel_art",
|
||||
"plastic",
|
||||
"pop_art",
|
||||
"pop_renaissance",
|
||||
"seamless",
|
||||
"street_art",
|
||||
"tablet_sketch",
|
||||
"urban_glow",
|
||||
"urban_sketching",
|
||||
"vanilla_dreams",
|
||||
"young_adult_book",
|
||||
"young_adult_book_2"
|
||||
],
|
||||
RecraftStyleV3.vector_illustration: [
|
||||
"bold_stroke",
|
||||
"chemistry",
|
||||
"colored_stencil",
|
||||
"contour_pop_art",
|
||||
"cosmics",
|
||||
"cutout",
|
||||
"depressive",
|
||||
"editorial",
|
||||
"emotional_flat",
|
||||
"engraving",
|
||||
"infographical",
|
||||
"line_art",
|
||||
"line_circuit",
|
||||
"linocut",
|
||||
"marker_outline",
|
||||
"mosaic",
|
||||
"naivector",
|
||||
"roundish_flat",
|
||||
"seamless",
|
||||
"segmented_colors",
|
||||
"sharp_contrast",
|
||||
"thin",
|
||||
"vector_photo",
|
||||
"vivid_shapes"
|
||||
],
|
||||
RecraftStyleV3.logo_raster: [
|
||||
"emblem_graffiti",
|
||||
"emblem_pop_art",
|
||||
"emblem_punk",
|
||||
"emblem_stamp",
|
||||
"emblem_vintage"
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class RecraftModel(str, Enum):
|
||||
recraftv3 = 'recraftv3'
|
||||
recraftv2 = 'recraftv2'
|
||||
|
||||
|
||||
class RecraftImageSize(str, Enum):
|
||||
res_1024x1024 = '1024x1024'
|
||||
res_1365x1024 = '1365x1024'
|
||||
res_1024x1365 = '1024x1365'
|
||||
res_1536x1024 = '1536x1024'
|
||||
res_1024x1536 = '1024x1536'
|
||||
res_1820x1024 = '1820x1024'
|
||||
res_1024x1820 = '1024x1820'
|
||||
res_1024x2048 = '1024x2048'
|
||||
res_2048x1024 = '2048x1024'
|
||||
res_1434x1024 = '1434x1024'
|
||||
res_1024x1434 = '1024x1434'
|
||||
res_1024x1280 = '1024x1280'
|
||||
res_1280x1024 = '1280x1024'
|
||||
res_1024x1707 = '1024x1707'
|
||||
res_1707x1024 = '1707x1024'
|
||||
|
||||
|
||||
class RecraftColorObject(BaseModel):
|
||||
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||
|
||||
|
||||
class RecraftControlsObject(BaseModel):
|
||||
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
class RecraftReturnedObject(BaseModel):
|
||||
image_id: str = Field(..., description='Unique identifier for the generated image')
|
||||
url: str = Field(..., description='URL to access the generated image')
|
||||
|
||||
|
||||
class RecraftImageGenerationResponse(BaseModel):
|
||||
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||
credits: int = Field(..., description='Number of credits used for the generation')
|
||||
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
||||
125
comfy_api_nodes/apis/request_logger.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
import datetime
import json
import logging
import folder_paths

# Get the logger instance
logger = logging.getLogger(__name__)

def get_log_directory():
    """
    Ensures the API log directory exists within ComfyUI's temp directory
    and returns its path.
    """
    base_temp_dir = folder_paths.get_temp_directory()
    log_dir = os.path.join(base_temp_dir, "api_logs")
    try:
        os.makedirs(log_dir, exist_ok=True)
    except Exception as e:
        logger.error(f"Error creating API log directory {log_dir}: {e}")
        # Fallback to base temp directory if sub-directory creation fails
        return base_temp_dir
    return log_dir

def _format_data_for_logging(data):
    """Helper to format data (dict, str, bytes) for logging."""
    if isinstance(data, bytes):
        try:
            return data.decode('utf-8') # Try to decode as text
        except UnicodeDecodeError:
            return f"[Binary data of length {len(data)} bytes]"
    elif isinstance(data, (dict, list)):
        try:
            return json.dumps(data, indent=2, ensure_ascii=False)
        except TypeError:
            return str(data) # Fallback for non-serializable objects
    return str(data)

def log_request_response(
    operation_id: str,
    request_method: str,
    request_url: str,
    request_headers: dict | None = None,
    request_params: dict | None = None,
    request_data: any = None,
    response_status_code: int | None = None,
    response_headers: dict | None = None,
    response_content: any = None,
    error_message: str | None = None
):
    """
    Logs API request and response details to a file in the temp/api_logs directory.
    """
    log_dir = get_log_directory()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
    filepath = os.path.join(log_dir, filename)

    log_content = []

    log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
    log_content.append(f"Operation ID: {operation_id}")
    log_content.append("-" * 30 + " REQUEST " + "-" * 30)
    log_content.append(f"Method: {request_method}")
    log_content.append(f"URL: {request_url}")
    if request_headers:
        log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
    if request_params:
        log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
    if request_data:
        log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")

    log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
    if response_status_code is not None:
        log_content.append(f"Status Code: {response_status_code}")
    if response_headers:
        log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
    if response_content:
        log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
    if error_message:
        log_content.append(f"Error:\n{error_message}")

    try:
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("\n".join(log_content))
        logger.debug(f"API log saved to: {filepath}")
    except Exception as e:
        logger.error(f"Error writing API log to {filepath}: {e}")

if __name__ == '__main__':
    # Example usage (for testing the logger directly)
    logger.setLevel(logging.DEBUG)
    # Mock folder_paths for direct execution if not running within ComfyUI full context
    if not hasattr(folder_paths, 'get_temp_directory'):
        class MockFolderPaths:
            def get_temp_directory(self):
                # Create a local temp dir for testing if needed
                p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
                os.makedirs(p, exist_ok=True)
                return p
        folder_paths = MockFolderPaths()

    log_request_response(
        operation_id="test_operation_get",
        request_method="GET",
        request_url="https://api.example.com/test",
        request_headers={"Authorization": "Bearer testtoken"},
        request_params={"param1": "value1"},
        response_status_code=200,
        response_content={"message": "Success!"}
    )
    log_request_response(
        operation_id="test_operation_post_error",
        request_method="POST",
        request_url="https://api.example.com/submit",
        request_data={"key": "value", "nested": {"num": 123}},
        error_message="Connection timed out"
    )
    log_request_response(
        operation_id="test_binary_response",
        request_method="GET",
        request_url="https://api.example.com/image.png",
        response_status_code=200,
        response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
    )
|
||||
127
comfy_api_nodes/apis/stability_api.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
class StabilityFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
webp = 'webp'
|
||||
|
||||
|
||||
class StabilityAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_3_2 = "3:2"
|
||||
ratio_2_3 = "2:3"
|
||||
ratio_5_4 = "5:4"
|
||||
ratio_4_5 = "4:5"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
def get_stability_style_presets(include_none=True):
|
||||
presets = []
|
||||
if include_none:
|
||||
presets.append("None")
|
||||
return presets + [x.value for x in StabilityStylePreset]
|
||||
|
||||
|
||||
class StabilityStylePreset(str, Enum):
|
||||
_3d_model = "3d-model"
|
||||
analog_film = "analog-film"
|
||||
anime = "anime"
|
||||
cinematic = "cinematic"
|
||||
comic_book = "comic-book"
|
||||
digital_art = "digital-art"
|
||||
enhance = "enhance"
|
||||
fantasy_art = "fantasy-art"
|
||||
isometric = "isometric"
|
||||
line_art = "line-art"
|
||||
low_poly = "low-poly"
|
||||
modeling_compound = "modeling-compound"
|
||||
neon_punk = "neon-punk"
|
||||
origami = "origami"
|
||||
photographic = "photographic"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class Stability_SD3_5_Model(str, Enum):
|
||||
sd3_5_large = "sd3.5-large"
|
||||
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||
sd3_5_medium = "sd3.5-medium"
|
||||
|
||||
|
||||
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||
text_to_image = "text-to-image"
|
||||
image_to_image = "image-to-image"
|
||||
|
||||
|
||||
class StabilityStable3_5Request(BaseModel):
|
||||
model: str = Field(...)
|
||||
mode: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
cfg_scale: float = Field(...)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class StabilityResultsGetResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
id: Optional[str] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
errors: Optional[list[str]] = Field(None)
|
||||
status: Optional[str] = Field(None)
|
||||
result: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
10
comfy_api_nodes/canary.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import av

ver = av.__version__.split(".")
if int(ver[0]) < 14:
    raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")

if int(ver[0]) == 14 and int(ver[1]) < 2:
    raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")

NODE_CLASS_MAPPINGS = {}
|
||||
116
comfy_api_nodes/mapper_utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from enum import Enum

from pydantic.fields import FieldInfo
from pydantic import BaseModel
from pydantic_core import PydanticUndefined

from comfy.comfy_types.node_typing import IO, InputTypeOptions

NodeInput = tuple[IO, InputTypeOptions]


def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
    config = {}
    if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
        config["default"] = field_info.default
    if hasattr(field_info, "description") and field_info.description is not None:
        config["tooltip"] = field_info.description
    return config


def _get_number_constraints_config(field_info: FieldInfo) -> dict:
    config = {}
    if hasattr(field_info, "metadata"):
        metadata = field_info.metadata
        for constraint in metadata:
            if hasattr(constraint, "ge"):
                config["min"] = constraint.ge
            if hasattr(constraint, "le"):
                config["max"] = constraint.le
            if hasattr(constraint, "multiple_of"):
                config["step"] = constraint.multiple_of
    return config


def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
    return IO.IMAGE, {
        **_create_base_config(field_info),
        **kwargs,
    }


def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
    return IO.STRING, {
        **_create_base_config(field_info),
        **kwargs,
    }


def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
    return IO.FLOAT, {
        **_create_base_config(field_info),
        **_get_number_constraints_config(field_info),
        **kwargs,
    }


def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
    return IO.INT, {
        **_create_base_config(field_info),
        **_get_number_constraints_config(field_info),
        **kwargs,
    }


def _model_field_to_combo_input(
    field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
) -> NodeInput:
    combo_config = {}
    if enum_type is not None:
        combo_config["options"] = [option.value for option in enum_type]
    combo_config = {
        **combo_config,
        **_create_base_config(field_info),
        **kwargs,
    }
    return IO.COMBO, combo_config


def model_field_to_node_input(
    input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
) -> NodeInput:
    """
    Maps a field from a Pydantic model to a Comfy node input.

    Args:
        input_type: The type of the input.
        base_model: The Pydantic model to map the field from.
        field_name: The name of the field to map.
        **kwargs: Additional key/values to include in the input options.

    Note:
        For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.

    Example:
        >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
        >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
        >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
    """
    field_info: FieldInfo = base_model.model_fields[field_name]
    result: NodeInput

    if input_type == IO.IMAGE:
        result = _model_field_to_image_input(field_info, **kwargs)
    elif input_type == IO.STRING:
        result = _model_field_to_string_input(field_info, **kwargs)
|
||||
elif input_type == IO.FLOAT:
|
||||
result = _model_field_to_float_input(field_info, **kwargs)
|
||||
elif input_type == IO.INT:
|
||||
result = _model_field_to_int_input(field_info, **kwargs)
|
||||
elif input_type == IO.COMBO:
|
||||
result = _model_field_to_combo_input(field_info, **kwargs)
|
||||
else:
|
||||
message = f"Invalid input type: {input_type}"
|
||||
raise ValueError(message)
|
||||
|
||||
return result
|
||||
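To make the mapper concrete, here is a hedged usage sketch showing how a node's INPUT_TYPES could be built from a Pydantic request model. MyRequest, MyStyle, and example_input_types are hypothetical; only model_field_to_node_input and IO come from the code above:

from enum import Enum

from pydantic import BaseModel, Field, confloat

from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.mapper_utils import model_field_to_node_input


class MyStyle(str, Enum):
    natural = "natural"
    vivid = "vivid"


class MyRequest(BaseModel):
    prompt: str = Field(..., description="Prompt for the image generation")
    cfg_scale: confloat(ge=1.0, le=10.0) = Field(7.0, description="Guidance scale")
    style: MyStyle = Field(MyStyle.natural)


def example_input_types():
    # ge/le constraints become "min"/"max", descriptions become tooltips,
    # and the enum's values populate the combo options.
    return {
        "required": {
            "prompt": model_field_to_node_input(IO.STRING, MyRequest, "prompt", multiline=True),
            "cfg_scale": model_field_to_node_input(IO.FLOAT, MyRequest, "cfg_scale"),
            "style": model_field_to_node_input(IO.COMBO, MyRequest, "style", enum_type=MyStyle),
        }
    }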
@@ -1,449 +0,0 @@
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
from inspect import cleandoc
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api_nodes.apis import (
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
|
||||
|
||||
|
||||
def downscale_input(image):
|
||||
samples = image.movedim(-1,1)
|
||||
# downscaling input images to roughly the same size as the outputs
|
||||
total = int(1536 * 1024)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1,-1)
|
||||
return s
|
||||
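# Worked example of the scaling above (illustrative numbers): for a 2048x2048 input,
#   scale_by = sqrt(1536 * 1024 / (2048 * 2048)) = sqrt(0.375) ≈ 0.61
# so the tensor is Lanczos-resampled to roughly 1254x1254; inputs already at or
# below ~1.57 megapixels (scale_by >= 1) are returned unchanged.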
|
||||
def validate_and_cast_response(response):
|
||||
# validate the parsed API response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise Exception("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors = []
|
||||
|
||||
# Process each image in the data array
|
||||
for image_data in data:
|
||||
image_url = image_data.url
|
||||
b64_data = image_data.b64_json
|
||||
|
||||
if not image_url and not b64_data:
|
||||
raise Exception("No image was generated in the response")
|
||||
|
||||
if b64_data:
|
||||
img_data = base64.b64decode(b64_data)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
elif image_url:
|
||||
img_response = requests.get(image_url)
|
||||
if img_response.status_code != 200:
|
||||
raise Exception("Failed to download the image")
|
||||
img = Image.open(io.BytesIO(img_response.content))
|
||||
|
||||
img = img.convert("RGBA")
|
||||
|
||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)
|
||||
|
||||
# Add to list of tensors
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||
|
||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
||||
so download or cache results if you need to keep them.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (IO.INT, {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31-1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "not implemented yet in backend",
|
||||
}),
|
||||
"size": (IO.COMBO, {
|
||||
"options": ["256x256", "512x512", "1024x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
}),
|
||||
"n": (IO.INT, {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
}),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
}),
|
||||
"mask": (IO.MASK, {
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None):
|
||||
model = "dall-e-2"
|
||||
path = "/proxy/openai/images/generations"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binary = None
|
||||
|
||||
if image is not None and mask is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
|
||||
input_tensor = image.squeeze().cpu()
|
||||
height, width, channels = input_tensor.shape
|
||||
rgba_tensor = torch.ones(height, width, 4, device="cpu")
|
||||
rgba_tensor[:, :, :channels] = input_tensor
|
||||
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
rgba_tensor[:,:,3] = (1-mask.squeeze().cpu())
|
||||
|
||||
rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze()
|
||||
|
||||
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr  # keep the BytesIO handle; it is sent as the multipart file payload
|
||||
img_binary.name = "image.png"
|
||||
elif image is not None or mask is not None:
|
||||
raise Exception("Dall-E 2 image editing requires an image AND a mask")
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size=size,
|
||||
seed=seed,
|
||||
),
|
||||
files={
|
||||
"image": img_binary,
|
||||
} if img_binary else None,
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
class OpenAIDalle3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||
|
||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
||||
so download or cache results if you need to keep them.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (IO.INT, {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31-1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "not implemented yet in backend",
|
||||
}),
|
||||
"quality" : (IO.COMBO, {
|
||||
"options": ["standard","hd"],
|
||||
"default": "standard",
|
||||
"tooltip": "Image quality",
|
||||
}),
|
||||
"style": (IO.COMBO, {
|
||||
"options": ["natural","vivid"],
|
||||
"default": "natural",
|
||||
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
|
||||
}),
|
||||
"size": (IO.COMBO, {
|
||||
"options": ["1024x1024", "1024x1792", "1792x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None):
|
||||
model = "dall-e-3"
|
||||
|
||||
# build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/openai/images/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=OpenAIImageGenerationRequest,
|
||||
response_model=OpenAIImageGenerationResponse
|
||||
),
|
||||
request=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
size=size,
|
||||
style=style,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
class OpenAIGPTImage1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||
|
||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
||||
so download or cache results if you need to keep them.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for GPT Image 1",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (IO.INT, {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31-1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "not implemented yet in backend",
|
||||
}),
|
||||
"quality": (IO.COMBO, {
|
||||
"options": ["low","medium","high"],
|
||||
"default": "low",
|
||||
"tooltip": "Image quality, affects cost and generation time.",
|
||||
}),
|
||||
"background": (IO.COMBO, {
|
||||
"options": ["opaque","transparent"],
|
||||
"default": "opaque",
|
||||
"tooltip": "Return image with or without background",
|
||||
}),
|
||||
"size": (IO.COMBO, {
|
||||
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
"default": "auto",
|
||||
"tooltip": "Image size",
|
||||
}),
|
||||
"n": (IO.INT, {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
}),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
}),
|
||||
"mask": (IO.MASK, {
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
}),
|
||||
"moderation": (IO.COMBO, {
|
||||
"options": ["low","auto"],
|
||||
"default": "low",
|
||||
"tooltip": "Moderation level",
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"):
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
|
||||
batch_size = image.shape[0]
|
||||
|
||||
|
||||
for i in range(batch_size):
|
||||
single_image = image[i:i+1]
|
||||
scaled_image = downscale_input(single_image).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = f"image_{i}.png"
|
||||
|
||||
img_binaries.append(img_binary)
|
||||
if batch_size == 1:
|
||||
files.append(("image", img_binary))
|
||||
else:
|
||||
files.append(("image[]", img_binary))
|
||||
|
||||
if mask is not None:
|
||||
# validate the mask/image pairing before touching image.shape
if image is None:
raise Exception("Cannot use a mask without an input image")
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple images")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
batch, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:,:,3] = (1-mask.squeeze().cpu())
|
||||
|
||||
scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format='PNG')
|
||||
mask_img_byte_arr.seek(0)
|
||||
mask_binary = mask_img_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
files.append(("mask", mask_binary))
|
||||
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
moderation=moderation,
|
||||
),
|
||||
files=files if files else None,
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"OpenAIDalle2": OpenAIDalle2,
|
||||
"OpenAIDalle3": OpenAIDalle3,
|
||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
}
|
||||
931
comfy_api_nodes/nodes_bfl.py
Normal file
@@ -0,0 +1,931 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from typing import Union
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLStatus,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxCannyImageRequest,
|
||||
BFLFluxDepthImageRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_aspect_ratio,
|
||||
process_image_response,
|
||||
resize_mask_to_image,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask]*3, dim=-1)
|
||||
return mask
|
||||
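# Shape walk-through (illustrative): a ComfyUI MASK of shape (B, H, W) becomes
# (B, H, W, 1) after unsqueeze(-1), then (B, H, W, 3) after concatenating three
# copies along the last dim, matching the (batch, height, width, channels)
# layout used for IMAGE tensors.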
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
|
||||
def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# implementation cross-checked against bfl-comfy-nodes:
|
||||
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||
start_time = time.time()
|
||||
retries_404 = 0
|
||||
max_retries_404 = 5
|
||||
retry_404_seconds = 2
|
||||
retry_202_seconds = 2
|
||||
retry_pending_seconds = 1
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should this `while True` loop be replaced with a check for whether the workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
time.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status_code == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
time.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status_code == 202:
|
||||
time.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||
# remove batch dimension if present
|
||||
if len(scaled_image.shape) > 3:
|
||||
scaled_image = scaled_image[0]
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
return base64.b64encode(img_byte_arr.getvalue()).decode()
|
||||
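# A hedged round-trip check for the encoder above (illustrative only; the helper
# name is hypothetical, it relies on the imports already at the top of this module,
# and it assumes the input is small enough that no downscaling triggers):
def _debug_roundtrip_example():
    dummy = torch.rand(1, 64, 64, 3)  # (batch, height, width, channels) in [0, 1]
    b64 = convert_image_to_base64(dummy)
    decoded = Image.open(io.BytesIO(base64.b64decode(b64)))
    return decoded.size, decoded.mode  # expected: (64, 64), "RGB"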
|
||||
|
||||
class FluxProUltraImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||
},
|
||||
),
|
||||
"raw": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "When True, generate less processed, more natural-looking images.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_prompt": (IO.IMAGE,),
|
||||
"image_prompt_strength": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Blend between the prompt and the image prompt.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
prompt_upsampling=False,
|
||||
raw=False,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
image_prompt_strength=0.1,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if image_prompt is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProUltraGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxProUltraGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
raw=raw,
|
||||
image_prompt=(
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
),
|
||||
image_prompt_strength=(
|
||||
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
class FluxProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"width": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1024,
|
||||
"min": 256,
|
||||
"max": 1440,
|
||||
"step": 32,
|
||||
},
|
||||
),
|
||||
"height": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 768,
|
||||
"min": 256,
|
||||
"max": 1440,
|
||||
"step": 32,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_prompt": (IO.IMAGE,),
|
||||
# "image_prompt_strength": (
|
||||
# IO.FLOAT,
|
||||
# {
|
||||
# "default": 0.1,
|
||||
# "min": 0.0,
|
||||
# "max": 1.0,
|
||||
# "step": 0.01,
|
||||
# "tooltip": "Blend between the prompt and the image prompt.",
|
||||
# },
|
||||
# ),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_upsampling,
|
||||
width: int,
|
||||
height: int,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image_prompt = (
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProExpandNode(ComfyNodeABC):
|
||||
"""
|
||||
Outpaints image based on prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"top": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the top of the image"
|
||||
},
|
||||
),
|
||||
"bottom": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the bottom of the image"
|
||||
},
|
||||
),
|
||||
"left": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the left side of the image"
|
||||
},
|
||||
),
|
||||
"right": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the right side of the image"
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 60,
|
||||
"min": 1.5,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
top: int,
|
||||
bottom: int,
|
||||
left: int,
|
||||
right: int,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image = convert_image_to_base64(image)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-expand/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxExpandImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxExpandImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
top=top,
|
||||
bottom=bottom,
|
||||
left=left,
|
||||
right=right,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
class FluxProFillNode(ComfyNodeABC):
|
||||
"""
|
||||
Inpaints image based on mask and prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"mask": (IO.MASK,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 60,
|
||||
"min": 1.5,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:, :, :, :3])
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-fill/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxFillImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxFillImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
mask=mask,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProCannyNode(ComfyNodeABC):
|
||||
"""
|
||||
Generate image using a control image (canny).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"control_image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"canny_low_threshold": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.01,
|
||||
"max": 0.99,
|
||||
"step": 0.01,
|
||||
"tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True"
|
||||
},
|
||||
),
|
||||
"canny_high_threshold": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.4,
|
||||
"min": 0.01,
|
||||
"max": 0.99,
|
||||
"step": 0.01,
|
||||
"tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True"
|
||||
},
|
||||
),
|
||||
"skip_preprocessing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 30,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
canny_low_threshold: float,
|
||||
canny_high_threshold: float,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:, :, :, :3])
|
||||
preprocessed_image = None
|
||||
|
||||
# scale canny threshold between 0-500, to match BFL's API
|
||||
def scale_value(value: float, min_val=0, max_val=500):
|
||||
return min_val + value * (max_val - min_val)
|
||||
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
|
||||
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
|
||||
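# Illustrative numbers with the defaults above: scale_value(0.1) = 0 + 0.1 * 500 = 50
# and scale_value(0.4) = 200, i.e. the UI's 0.01-0.99 floats map onto the 0-500
# integer thresholds expected by BFL's API.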
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
canny_low_threshold = None
|
||||
canny_high_threshold = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-canny/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxCannyImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxCannyImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
canny_low_threshold=canny_low_threshold,
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProDepthNode(ComfyNodeABC):
|
||||
"""
|
||||
Generate image using a control image (depth).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"control_image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"skip_preprocessing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 15,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
preprocessed_image = None
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-depth/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxDepthImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxDepthImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FluxProUltraImageNode": FluxProUltraImageNode,
|
||||
# "FluxProImageNode": FluxProImageNode,
|
||||
"FluxProExpandNode": FluxProExpandNode,
|
||||
"FluxProFillNode": FluxProFillNode,
|
||||
"FluxProCannyNode": FluxProCannyNode,
|
||||
"FluxProDepthNode": FluxProDepthNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||
"FluxProFillNode": "Flux.1 Fill Image",
|
||||
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||
"FluxProDepthNode": "Flux.1 Depth Control Image",
|
||||
}
|
||||
801
comfy_api_nodes/nodes_ideogram.py
Normal file
@@ -0,0 +1,801 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
IdeogramGenerateRequest,
|
||||
IdeogramGenerateResponse,
|
||||
ImageRequest,
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
bytesio_to_image_tensor,
|
||||
resize_mask_to_image,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
"512 x 1536":"RESOLUTION_512_1536",
|
||||
"576 x 1408":"RESOLUTION_576_1408",
|
||||
"576 x 1472":"RESOLUTION_576_1472",
|
||||
"576 x 1536":"RESOLUTION_576_1536",
|
||||
"640 x 1024":"RESOLUTION_640_1024",
|
||||
"640 x 1344":"RESOLUTION_640_1344",
|
||||
"640 x 1408":"RESOLUTION_640_1408",
|
||||
"640 x 1472":"RESOLUTION_640_1472",
|
||||
"640 x 1536":"RESOLUTION_640_1536",
|
||||
"704 x 1152":"RESOLUTION_704_1152",
|
||||
"704 x 1216":"RESOLUTION_704_1216",
|
||||
"704 x 1280":"RESOLUTION_704_1280",
|
||||
"704 x 1344":"RESOLUTION_704_1344",
|
||||
"704 x 1408":"RESOLUTION_704_1408",
|
||||
"704 x 1472":"RESOLUTION_704_1472",
|
||||
"720 x 1280":"RESOLUTION_720_1280",
|
||||
"736 x 1312":"RESOLUTION_736_1312",
|
||||
"768 x 1024":"RESOLUTION_768_1024",
|
||||
"768 x 1088":"RESOLUTION_768_1088",
|
||||
"768 x 1152":"RESOLUTION_768_1152",
|
||||
"768 x 1216":"RESOLUTION_768_1216",
|
||||
"768 x 1232":"RESOLUTION_768_1232",
|
||||
"768 x 1280":"RESOLUTION_768_1280",
|
||||
"768 x 1344":"RESOLUTION_768_1344",
|
||||
"832 x 960":"RESOLUTION_832_960",
|
||||
"832 x 1024":"RESOLUTION_832_1024",
|
||||
"832 x 1088":"RESOLUTION_832_1088",
|
||||
"832 x 1152":"RESOLUTION_832_1152",
|
||||
"832 x 1216":"RESOLUTION_832_1216",
|
||||
"832 x 1248":"RESOLUTION_832_1248",
|
||||
"864 x 1152":"RESOLUTION_864_1152",
|
||||
"896 x 960":"RESOLUTION_896_960",
|
||||
"896 x 1024":"RESOLUTION_896_1024",
|
||||
"896 x 1088":"RESOLUTION_896_1088",
|
||||
"896 x 1120":"RESOLUTION_896_1120",
|
||||
"896 x 1152":"RESOLUTION_896_1152",
|
||||
"960 x 832":"RESOLUTION_960_832",
|
||||
"960 x 896":"RESOLUTION_960_896",
|
||||
"960 x 1024":"RESOLUTION_960_1024",
|
||||
"960 x 1088":"RESOLUTION_960_1088",
|
||||
"1024 x 640":"RESOLUTION_1024_640",
|
||||
"1024 x 768":"RESOLUTION_1024_768",
|
||||
"1024 x 832":"RESOLUTION_1024_832",
|
||||
"1024 x 896":"RESOLUTION_1024_896",
|
||||
"1024 x 960":"RESOLUTION_1024_960",
|
||||
"1024 x 1024":"RESOLUTION_1024_1024",
|
||||
"1088 x 768":"RESOLUTION_1088_768",
|
||||
"1088 x 832":"RESOLUTION_1088_832",
|
||||
"1088 x 896":"RESOLUTION_1088_896",
|
||||
"1088 x 960":"RESOLUTION_1088_960",
|
||||
"1120 x 896":"RESOLUTION_1120_896",
|
||||
"1152 x 704":"RESOLUTION_1152_704",
|
||||
"1152 x 768":"RESOLUTION_1152_768",
|
||||
"1152 x 832":"RESOLUTION_1152_832",
|
||||
"1152 x 864":"RESOLUTION_1152_864",
|
||||
"1152 x 896":"RESOLUTION_1152_896",
|
||||
"1216 x 704":"RESOLUTION_1216_704",
|
||||
"1216 x 768":"RESOLUTION_1216_768",
|
||||
"1216 x 832":"RESOLUTION_1216_832",
|
||||
"1232 x 768":"RESOLUTION_1232_768",
|
||||
"1248 x 832":"RESOLUTION_1248_832",
|
||||
"1280 x 704":"RESOLUTION_1280_704",
|
||||
"1280 x 720":"RESOLUTION_1280_720",
|
||||
"1280 x 768":"RESOLUTION_1280_768",
|
||||
"1280 x 800":"RESOLUTION_1280_800",
|
||||
"1312 x 736":"RESOLUTION_1312_736",
|
||||
"1344 x 640":"RESOLUTION_1344_640",
|
||||
"1344 x 704":"RESOLUTION_1344_704",
|
||||
"1344 x 768":"RESOLUTION_1344_768",
|
||||
"1408 x 576":"RESOLUTION_1408_576",
|
||||
"1408 x 640":"RESOLUTION_1408_640",
|
||||
"1408 x 704":"RESOLUTION_1408_704",
|
||||
"1472 x 576":"RESOLUTION_1472_576",
|
||||
"1472 x 640":"RESOLUTION_1472_640",
|
||||
"1472 x 704":"RESOLUTION_1472_704",
|
||||
"1536 x 512":"RESOLUTION_1536_512",
|
||||
"1536 x 576":"RESOLUTION_1536_576",
|
||||
"1536 x 640":"RESOLUTION_1536_640",
|
||||
}
|
||||
|
||||
V1_V2_RATIO_MAP = {
|
||||
"1:1":"ASPECT_1_1",
|
||||
"4:3":"ASPECT_4_3",
|
||||
"3:4":"ASPECT_3_4",
|
||||
"16:9":"ASPECT_16_9",
|
||||
"9:16":"ASPECT_9_16",
|
||||
"2:1":"ASPECT_2_1",
|
||||
"1:2":"ASPECT_1_2",
|
||||
"3:2":"ASPECT_3_2",
|
||||
"2:3":"ASPECT_2_3",
|
||||
"4:5":"ASPECT_4_5",
|
||||
"5:4":"ASPECT_5_4",
|
||||
}
|
||||
|
||||
V3_RATIO_MAP = {
|
||||
"1:3":"1x3",
|
||||
"3:1":"3x1",
|
||||
"1:2":"1x2",
|
||||
"2:1":"2x1",
|
||||
"9:16":"9x16",
|
||||
"16:9":"16x9",
|
||||
"10:16":"10x16",
|
||||
"16:10":"16x10",
|
||||
"2:3":"2x3",
|
||||
"3:2":"3x2",
|
||||
"3:4":"3x4",
|
||||
"4:3":"4x3",
|
||||
"4:5":"4x5",
|
||||
"5:4":"5x4",
|
||||
"1:1":"1x1",
|
||||
}
|
||||
|
||||
V3_RESOLUTIONS= [
|
||||
"Auto",
|
||||
"512x1536",
|
||||
"576x1408",
|
||||
"576x1472",
|
||||
"576x1536",
|
||||
"640x1344",
|
||||
"640x1408",
|
||||
"640x1472",
|
||||
"640x1536",
|
||||
"704x1152",
|
||||
"704x1216",
|
||||
"704x1280",
|
||||
"704x1344",
|
||||
"704x1408",
|
||||
"704x1472",
|
||||
"736x1312",
|
||||
"768x1088",
|
||||
"768x1216",
|
||||
"768x1280",
|
||||
"768x1344",
|
||||
"800x1280",
|
||||
"832x960",
|
||||
"832x1024",
|
||||
"832x1088",
|
||||
"832x1152",
|
||||
"832x1216",
|
||||
"832x1248",
|
||||
"864x1152",
|
||||
"896x960",
|
||||
"896x1024",
|
||||
"896x1088",
|
||||
"896x1120",
|
||||
"896x1152",
|
||||
"960x832",
|
||||
"960x896",
|
||||
"960x1024",
|
||||
"960x1088",
|
||||
"1024x832",
|
||||
"1024x896",
|
||||
"1024x960",
|
||||
"1024x1024",
|
||||
"1088x768",
|
||||
"1088x832",
|
||||
"1088x896",
|
||||
"1088x960",
|
||||
"1120x896",
|
||||
"1152x704",
|
||||
"1152x832",
|
||||
"1152x864",
|
||||
"1152x896",
|
||||
"1216x704",
|
||||
"1216x768",
|
||||
"1216x832",
|
||||
"1248x832",
|
||||
"1280x704",
|
||||
"1280x768",
|
||||
"1280x800",
|
||||
"1312x736",
|
||||
"1344x640",
|
||||
"1344x704",
|
||||
"1344x768",
|
||||
"1408x576",
|
||||
"1408x640",
|
||||
"1408x704",
|
||||
"1472x576",
|
||||
"1472x640",
|
||||
"1472x704",
|
||||
"1536x512",
|
||||
"1536x576",
|
||||
"1536x640"
|
||||
]
|
||||
|
||||
def download_and_process_images(image_urls):
|
||||
"""Helper function to download and process multiple images from URLs"""
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors = []
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
# Stack tensors to match (N, height, width, channels)
|
||||
if image_tensors:
|
||||
stacked_tensors = torch.cat(image_tensors, dim=0)
|
||||
else:
|
||||
raise Exception("No valid images were processed")
|
||||
|
||||
return stacked_tensors
|
||||
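# A hedged usage sketch for the helper above (the function name and URLs are
# placeholders; in the nodes the URLs come from the Ideogram API response):
def _example_batch_download():
    urls = [
        "https://example.com/gen_0.png",
        "https://example.com/gen_1.png",
    ]
    # Each URL is downloaded and decoded to an RGB tensor with a leading batch
    # dim, then torch.cat along dim 0 yields an (N, H, W, 3) batch; this assumes
    # every returned image shares the same height and width.
    return download_and_process_images(urls)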
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v1"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V1_RES_MAP.keys()),
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"style_type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
"default": "NONE",
|
||||
"tooltip": "Style type for generation (V2 only)",
|
||||
},
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
# {
|
||||
# "multiline": False,
|
||||
# "default": "",
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v2"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
resolution="Auto",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
style_type="NONE",
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
# Determine the model based on turbo setting
|
||||
model = "V_2_TURBO" if turbo else "V_2"
|
||||
|
||||
# Handle resolution vs aspect_ratio logic
|
||||
# If resolution is not AUTO, it overrides aspect_ratio
|
||||
final_resolution = None
|
||||
final_aspect_ratio = None
|
||||
|
||||
if resolution != "AUTO":
|
||||
final_resolution = resolution
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation or editing",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V3_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": V3_RESOLUTIONS,
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
"rendering_speed": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||
"default": "BALANCED",
|
||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v3"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
image=None,
|
||||
mask=None,
|
||||
resolution="Auto",
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||
|
||||
# Process image and mask
|
||||
input_tensor = image.squeeze().cpu()
|
||||
# Resize mask to match image dimensions
|
||||
mask = resize_mask_to_image(mask, image, allow_gradient=False)
|
||||
# Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention).
|
||||
mask = 1.0 - mask
|
||||
|
||||
# Validate mask dimensions match image
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
|
||||
# Process image
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = "image.png"
|
||||
|
||||
# Process mask - white areas will be replaced
|
||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
mask_binary = mask_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
|
||||
# Create edit request
|
||||
edit_request = IdeogramV3EditRequest(
|
||||
prompt=prompt,
|
||||
rendering_speed=rendering_speed,
|
||||
)
|
||||
|
||||
# Add optional parameters
|
||||
if magic_prompt_option != "AUTO":
|
||||
edit_request.magic_prompt = magic_prompt_option
|
||||
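# A seed of 0 is treated as unset here and is not forwarded to the API.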
if seed != 0:
|
||||
edit_request.seed = seed
|
||||
if num_images > 1:
|
||||
edit_request.num_images = num_images
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3EditRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
files={
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
# If only one of image or mask is provided, raise an error
|
||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||
else:
|
||||
# Generation mode
|
||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||
|
||||
# Create generation request
|
||||
gen_request = IdeogramV3Request(
|
||||
prompt=prompt,
|
||||
rendering_speed=rendering_speed,
|
||||
)
|
||||
|
||||
# Handle resolution vs aspect ratio
|
||||
if resolution != "Auto":
|
||||
gen_request.resolution = resolution
|
||||
elif aspect_ratio != "1:1":
|
||||
v3_aspect = V3_RATIO_MAP.get(aspect_ratio)
|
||||
if v3_aspect:
|
||||
gen_request.aspect_ratio = v3_aspect
|
||||
|
||||
# Add optional parameters
|
||||
if magic_prompt_option != "AUTO":
|
||||
gen_request.magic_prompt = magic_prompt_option
|
||||
if seed != 0:
|
||||
gen_request.seed = seed
|
||||
if num_images > 1:
|
||||
gen_request.num_images = num_images
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3Request,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"IdeogramV1": IdeogramV1,
|
||||
"IdeogramV2": IdeogramV2,
|
||||
"IdeogramV3": IdeogramV3,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV1": "Ideogram V1",
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
1758
comfy_api_nodes/nodes_kling.py
Normal file
1758
comfy_api_nodes/nodes_kling.py
Normal file
File diff suppressed because it is too large
737
comfy_api_nodes/nodes_luma.py
Normal file
737
comfy_api_nodes/nodes_luma.py
Normal file
@@ -0,0 +1,737 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
LumaVideoModel,
|
||||
LumaVideoOutputResolution,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaAspectRatio,
|
||||
LumaState,
|
||||
LumaImageGenerationRequest,
|
||||
LumaGenerationRequest,
|
||||
LumaGeneration,
|
||||
LumaCharacterRef,
|
||||
LumaModifyImageRef,
|
||||
LumaImageIdentity,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaImageReference,
|
||||
LumaKeyframes,
|
||||
LumaConceptChain,
|
||||
LumaIO,
|
||||
get_luma_concepts,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
LUMA_T2V_AVERAGE_DURATION = 105
|
||||
LUMA_I2V_AVERAGE_DURATION = 100
|
||||
|
||||
def image_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
|
||||
|
||||
def video_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
||||
|
||||
class LumaReferenceNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (LumaIO.LUMA_REF,)
|
||||
RETURN_NAMES = ("luma_ref",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_luma_reference"
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image to use as reference.",
|
||||
},
|
||||
),
|
||||
"weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of image reference.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
|
||||
}
|
||||
|
||||
def create_luma_reference(
|
||||
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
):
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
luma_ref = LumaReferenceChain()
|
||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||
return (luma_ref,)
|
||||
|
||||
|
||||
class LumaConceptsNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
|
||||
RETURN_NAMES = ("luma_concepts",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_concepts"
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"concept1": (get_luma_concepts(include_none=True),),
|
||||
"concept2": (get_luma_concepts(include_none=True),),
|
||||
"concept3": (get_luma_concepts(include_none=True),),
|
||||
"concept4": (get_luma_concepts(include_none=True),),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def create_concepts(
|
||||
self,
|
||||
concept1: str,
|
||||
concept2: str,
|
||||
concept3: str,
|
||||
concept4: str,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
):
|
||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||
if luma_concepts is not None:
|
||||
chain = luma_concepts.clone_and_merge(chain)
|
||||
return (chain,)
|
||||
|
||||
|
||||
class LumaImageGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
"style_image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of style image. Ignored if no style_image provided.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_luma_ref": (
|
||||
LumaIO.LUMA_REF,
|
||||
{
|
||||
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
"style_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Style reference image; only 1 image will be used."},
|
||||
),
|
||||
"character_image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
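# All uploaded character reference images are grouped under a single identity (identity0).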
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
image_ref=api_image_ref,
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
return (img,)
|
||||
|
||||
def _convert_luma_refs(
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
luma_urls.append(download_urls[0])
|
||||
ref_count += 1
|
||||
if ref_count >= max_refs:
|
||||
break
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
def _convert_style_image(
|
||||
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
|
||||
|
||||
class LumaImageModifyNode(ComfyNodeABC):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 0.98,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image_weight: float,
|
||||
seed,
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
# first, upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs,
|
||||
)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
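# Invert image_weight before sending (weight = 1 - image_weight, clamped to [0.0, 0.98]) so that, per the tooltip, values closer to 1.0 modify the image less.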
modify_image_ref=LumaModifyImageRef(
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
return (img,)
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
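# The ray-1.6 model does not take duration/resolution here, so both are dropped when it is selected.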
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=duration,
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=unique_id,
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
|
||||
# "default": LumaAspectRatio.ratio_16_9,
|
||||
# }),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"first_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "First frame of generated video."},
|
||||
),
|
||||
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
resolution: str,
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
unique_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
loop=loop,
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=unique_id,
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
def _convert_to_keyframes(
|
||||
self,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
return None
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LumaImageNode": LumaImageGenerationNode,
|
||||
"LumaImageModifyNode": LumaImageModifyNode,
|
||||
"LumaVideoNode": LumaTextToVideoGenerationNode,
|
||||
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
|
||||
"LumaReferenceNode": LumaReferenceNode,
|
||||
"LumaConceptsNode": LumaConceptsNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LumaImageNode": "Luma Text to Image",
|
||||
"LumaImageModifyNode": "Luma Image to Image",
|
||||
"LumaVideoNode": "Luma Text to Video",
|
||||
"LumaImageToVideoNode": "Luma Image to Video",
|
||||
"LumaReferenceNode": "Luma Reference",
|
||||
"LumaConceptsNode": "Luma Concepts",
|
||||
}
|
||||
332
comfy_api_nodes/nodes_minimax.py
Normal file
332
comfy_api_nodes/nodes_minimax.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
MinimaxVideoGenerationRequest,
|
||||
MinimaxVideoGenerationResponse,
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
Model
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
I2V_AVERAGE_DURATION = 114
|
||||
T2V_AVERAGE_DURATION = 234
|
||||
|
||||
class MinimaxTextToVideoNode:
|
||||
"""
|
||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
AVERAGE_DURATION = T2V_AVERAGE_DURATION
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"T2V-01",
|
||||
"T2V-01-Director",
|
||||
],
|
||||
{
|
||||
"default": "T2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt_text,
|
||||
seed=0,
|
||||
model="T2V-01",
|
||||
image: torch.Tensor=None, # used for ImageToVideo
|
||||
subject: torch.Tensor=None, # used for SubjectToVideo
|
||||
unique_id: Union[str, None]=None,
|
||||
**kwargs,
|
||||
):
|
||||
'''
|
||||
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
|
||||
'''
|
||||
if image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=Model(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
subject_reference=subject_reference,
|
||||
prompt_optimizer=None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response = video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=self.AVERAGE_DURATION,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_result = video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
file_result = file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
if unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, unique_id)
|
||||
|
||||
video_io = download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
AVERAGE_DURATION = I2V_AVERAGE_DURATION
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image to use as first frame of video generation"
|
||||
},
|
||||
),
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"I2V-01-Director",
|
||||
"I2V-01",
|
||||
"I2V-01-live",
|
||||
],
|
||||
{
|
||||
"default": "I2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
"""
|
||||
Generates videos synchronously based on a subject reference image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
AVERAGE_DURATION = T2V_AVERAGE_DURATION
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"subject": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image of subject to reference video generation"
|
||||
},
|
||||
),
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"S2V-01",
|
||||
],
|
||||
{
|
||||
"default": "S2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||
}
|
||||
502
comfy_api_nodes/nodes_openai.py
Normal file
502
comfy_api_nodes/nodes_openai.py
Normal file
@@ -0,0 +1,502 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_and_cast_response,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["256x256", "512x512", "1024x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
"n": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
},
|
||||
),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
unique_id=None,
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-2"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binary = None
|
||||
|
||||
if image is not None and mask is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
content_type = "multipart/form-data"
|
||||
request_class = OpenAIImageEditRequest
|
||||
|
||||
input_tensor = image.squeeze().cpu()
|
||||
height, width, channels = input_tensor.shape
|
||||
rgba_tensor = torch.ones(height, width, 4, device="cpu")
|
||||
rgba_tensor[:, :, :channels] = input_tensor
|
||||
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze()
|
||||
|
||||
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = "image.png"
|
||||
elif image is not None or mask is not None:
|
||||
raise Exception("Dall-E 2 image editing requires an image AND a mask")
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size=size,
|
||||
seed=seed,
|
||||
),
|
||||
files=(
|
||||
{
|
||||
"image": img_binary,
|
||||
}
|
||||
if img_binary
|
||||
else None
|
||||
),
|
||||
content_type=content_type,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
class OpenAIDalle3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["standard", "hd"],
|
||||
"default": "standard",
|
||||
"tooltip": "Image quality",
|
||||
},
|
||||
),
|
||||
"style": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["natural", "vivid"],
|
||||
"default": "natural",
|
||||
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["1024x1024", "1024x1792", "1792x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
style="natural",
|
||||
quality="standard",
|
||||
size="1024x1024",
|
||||
unique_id=None,
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-3"
|
||||
|
||||
# build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/openai/images/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=OpenAIImageGenerationRequest,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
size=size,
|
||||
style=style,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
class OpenAIGPTImage1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for GPT Image 1",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["low", "medium", "high"],
|
||||
"default": "low",
|
||||
"tooltip": "Image quality, affects cost and generation time.",
|
||||
},
|
||||
),
|
||||
"background": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["opaque", "transparent"],
|
||||
"default": "opaque",
|
||||
"tooltip": "Return image with or without background",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
"default": "auto",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
"n": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
},
|
||||
),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
quality="low",
|
||||
background="opaque",
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
unique_id=None,
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type="application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
content_type ="multipart/form-data"
|
||||
|
||||
batch_size = image.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = f"image_{i}.png"
|
||||
|
||||
img_binaries.append(img_binary)
|
||||
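# Single uploads use the plain "image" field; batches use "image[]" so the multipart form carries an array of files.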
if batch_size == 1:
|
||||
files.append(("image", img_binary))
|
||||
else:
|
||||
files.append(("image[]", img_binary))
|
||||
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
raise Exception("Cannot use a mask without an input image")
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
batch, height, width = mask.shape
|
||||
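# The edits endpoint replaces transparent pixels, so the mask is converted into an alpha channel where white (masked) areas become fully transparent.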
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
mask_binary = mask_img_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
files.append(("mask", mask_binary))
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
),
|
||||
files=files if files else None,
|
||||
content_type=content_type,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"OpenAIDalle2": OpenAIDalle2,
|
||||
"OpenAIDalle3": OpenAIDalle3,
|
||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
}
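For reference, the single-image vs. batch field naming used in api_call above ("image" for one reference image, "image[]" for several) can be distilled into a small standalone helper. This is an illustrative sketch only: the helper name is not part of the diff, it skips the downscaling step, and it assumes the (B, H, W, C) float tensor layout in the 0..1 range that ComfyUI IMAGE outputs use.

import io
import numpy as np
import torch
from PIL import Image

def tensor_batch_to_multipart(image: torch.Tensor) -> list:
    """Encode each image in the batch as a PNG part, named as the edits endpoint expects."""
    files = []
    batch_size = image.shape[0]
    field = "image" if batch_size == 1 else "image[]"
    for i in range(batch_size):
        arr = (image[i].cpu().numpy() * 255).astype(np.uint8)
        buf = io.BytesIO()
        Image.fromarray(arr).save(buf, format="PNG")
        buf.seek(0)
        buf.name = f"image_{i}.png"  # filename hint picked up by the multipart encoder
        files.append((field, buf))
    return files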
|
||||
779
comfy_api_nodes/nodes_pika.py
Normal file
@@ -0,0 +1,779 @@
|
||||
"""
|
||||
Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Optional, TypeVar
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from comfy_api_nodes.apis import (
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
PikaGenerateResponse,
|
||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
PikaVideoResponse,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
IngredientsMode,
|
||||
PikaDurationEnum,
|
||||
PikaResolutionEnum,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
Pikaffect,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||
|
||||
PIKA_API_VERSION = "2.2"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
class PikaApiError(Exception):
|
||||
"""Exception for Pika API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
||||
"""Check if the video response is valid."""
|
||||
return hasattr(response, "url") and response.url is not None
|
||||
|
||||
|
||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
||||
"""Check if the initial response is valid."""
|
||||
return hasattr(response, "video_id") and response.video_id is not None
|
||||
|
||||
|
||||
class PikaNodeBase(ComfyNodeABC):
|
||||
"""Base class for Pika nodes."""
|
||||
|
||||
@classmethod
|
||||
def get_base_inputs_types(
|
||||
cls, request_model
|
||||
) -> dict[str, tuple[IO, InputTypeOptions]]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return {
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
request_model,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
request_model,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
request_model,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
"resolution": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
request_model,
|
||||
"resolution",
|
||||
enum_type=PikaResolutionEnum,
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
request_model,
|
||||
"duration",
|
||||
enum_type=PikaDurationEnum,
|
||||
),
|
||||
}
|
||||
|
||||
CATEGORY = "api node/video/Pika"
|
||||
API_NODE = True
|
||||
FUNCTION = "api_call"
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def poll_for_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> PikaGenerateResponse:
|
||||
polling_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=[
|
||||
"finished",
|
||||
],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
status_extractor=lambda response: (
|
||||
response.status.value if response.status else None
|
||||
),
|
||||
progress_extractor=lambda response: (
|
||||
response.progress if hasattr(response, "progress") else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (
|
||||
response.url if hasattr(response, "url") else None
|
||||
),
|
||||
node_id=node_id,
|
||||
estimated_duration=60
|
||||
)
|
||||
return polling_operation.execute()
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
"""Executes the initial operation then polls for the task status until it is completed.
|
||||
|
||||
Args:
|
||||
initial_operation: The initial operation to execute.
|
||||
auth_kwargs: The authentication token(s) to use for the API call.
node_id: The id of the calling node, used for progress and result reporting.
|
||||
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
"""
|
||||
initial_response = initial_operation.execute()
|
||||
if not is_valid_initial_response(initial_response):
|
||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = self.poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
)
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
video_url = str(final_response.url)
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
|
||||
return (download_url_to_video_output(video_url),)
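# Contract used by every node below: the generate endpoints return a video_id in the
# initial response; execute_task then polls PATH_VIDEO_GET/{video_id} until the status
# reaches "finished" and downloads the url field of that final response as the VIDEO output.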
|
||||
|
||||
|
||||
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "The image to convert to video"},
|
||||
),
|
||||
**cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
**cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost),
|
||||
"aspect_ratio": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
"aspectRatio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaScenesV2_2(PikaNodeBase):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
image_ingredient_input = (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Image that will be used as ingredient to create a video."},
|
||||
)
|
||||
return {
|
||||
"required": {
|
||||
**cls.get_base_inputs_types(
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
),
|
||||
"ingredients_mode": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
"ingredientsMode",
|
||||
enum_type=IngredientsMode,
|
||||
default="creative",
|
||||
),
|
||||
"aspect_ratio": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
"aspectRatio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_ingredient_1": image_ingredient_input,
|
||||
"image_ingredient_2": image_ingredient_input,
|
||||
"image_ingredient_3": image_ingredient_input,
|
||||
"image_ingredient_4": image_ingredient_input,
|
||||
"image_ingredient_5": image_ingredient_input,
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
unique_id: str,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert all passed images to BytesIO
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
image_ingredient_2,
|
||||
image_ingredient_3,
|
||||
image_ingredient_4,
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
all_image_bytes_io.append(image_bytes_io)
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikAdditionsNode(PikaNodeBase):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to add an image to."}),
|
||||
"image": (IO.IMAGE, {"tooltip": "The image to add to the video."}),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
]
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaSwapsNode(PikaNodeBase):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "The image used to replace the masked object in the video."
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{"tooltip": "Use the mask to define areas in the video to replace"},
|
||||
),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert mask to binary mask with three channels
|
||||
mask = torch.round(mask)
|
||||
mask = mask.repeat(1, 3, 1, 1)
|
||||
|
||||
# Convert 3-channel binary mask to BytesIO
|
||||
mask_bytes_io = io.BytesIO()
|
||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
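# Note: this writes the raw (rounded, 3-channel) tensor bytes into the buffer; the
# "mask.png" filename attached below names the multipart part, it does not imply PNG encoding.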
|
||||
mask_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
||||
]
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaffectsNode(PikaNodeBase):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "The reference image to apply the Pikaffect to."},
|
||||
),
|
||||
"pikaffect": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"pikaffect",
|
||||
enum_type=Pikaffect,
|
||||
default="Cake-ify",
|
||||
),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
|
||||
"image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
|
||||
**cls.get_base_inputs_types(
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
pika_files = [
|
||||
(
|
||||
"keyFrames",
|
||||
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
||||
),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
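# Both frames are posted under the same "keyFrames" field name, so the multipart body
# carries them as a repeated field in start-then-end order.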
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PikaImageToVideoNode2_2": PikaImageToVideoV2_2,
|
||||
"PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2,
|
||||
"PikaScenesV2_2": PikaScenesV2_2,
|
||||
"Pikadditions": PikAdditionsNode,
|
||||
"Pikaswaps": PikaSwapsNode,
|
||||
"Pikaffects": PikaffectsNode,
|
||||
"PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PikaImageToVideoNode2_2": "Pika Image to Video",
|
||||
"PikaTextToVideoNode2_2": "Pika Text to Video",
|
||||
"PikaScenesV2_2": "Pika Scenes (Video Image Composition)",
|
||||
"Pikadditions": "Pikadditions (Video Object Insertion)",
|
||||
"Pikaswaps": "Pika Swaps (Video Object Replacement)",
|
||||
"Pikaffects": "Pikaffects (Video Effects)",
|
||||
"PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video",
|
||||
}
|
||||
525
comfy_api_nodes/nodes_pixverse.py
Normal file
@@ -0,0 +1,525 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from comfy_api_nodes.apis.pixverse_api import (
|
||||
PixverseTextVideoRequest,
|
||||
PixverseImageVideoRequest,
|
||||
PixverseTransitionVideoRequest,
|
||||
PixverseImageUploadResponse,
|
||||
PixverseVideoResponse,
|
||||
PixverseGenerationStatusResponse,
|
||||
PixverseAspectRatio,
|
||||
PixverseQuality,
|
||||
PixverseDuration,
|
||||
PixverseMotionMode,
|
||||
PixverseStatus,
|
||||
PixverseIO,
|
||||
pixverse_templates,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
import torch
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
AVERAGE_DURATION_T2V = 32
|
||||
AVERAGE_DURATION_I2V = 30
|
||||
AVERAGE_DURATION_T2T = 52
|
||||
|
||||
|
||||
def get_video_url_from_response(
|
||||
response: PixverseGenerationStatusResponse,
|
||||
) -> Optional[str]:
|
||||
if response.Resp is None or response.Resp.url is None:
|
||||
return None
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/image/upload",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseImageUploadResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(
|
||||
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
|
||||
)
|
||||
|
||||
return response_upload.Resp.img_id
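The id returned here is what the image-to-video and transition nodes below feed into their generation requests. A minimal sketch of that handoff (the auth dict, prompt, and parameter values are illustrative only, not part of the diff):

img_id = upload_image_to_pixverse(image, auth_kwargs=auth)
request = PixverseImageVideoRequest(
    img_id=img_id,
    prompt="a sailboat drifting at dusk",
    quality=PixverseQuality.res_540p,
    duration=PixverseDuration.dur_5,
    motion_mode=PixverseMotionMode.normal,
    seed=0,
)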
|
||||
|
||||
|
||||
class PixverseTemplateNode:
|
||||
"""
|
||||
Select template for PixVerse Video generation.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (PixverseIO.TEMPLATE,)
|
||||
RETURN_NAMES = ("pixverse_template",)
|
||||
FUNCTION = "create_template"
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"template": (list(pixverse_templates.keys()),),
|
||||
}
|
||||
}
|
||||
|
||||
def create_template(self, template: str):
|
||||
template_id = pixverse_templates.get(template, None)
|
||||
if template_id is None:
|
||||
raise Exception(f"Template '{template}' is not recognized.")
|
||||
# just return the integer
|
||||
return (template_id,)
|
||||
|
||||
|
||||
class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
"pixverse_template": (
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/text/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTextVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTextVideoRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos from an input image based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
"pixverse_template": (
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/img/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseImageVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseImageVideoRequest(
|
||||
img_id=img_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates transition videos between a first and last frame based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"first_frame": (IO.IMAGE,),
|
||||
"last_frame": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/transition/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTransitionVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTransitionVideoRequest(
|
||||
first_frame_img=first_frame_id,
|
||||
last_frame_img=last_frame_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PixverseTextToVideoNode": PixverseTextToVideoNode,
|
||||
"PixverseImageToVideoNode": PixverseImageToVideoNode,
|
||||
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
|
||||
"PixverseTemplateNode": PixverseTemplateNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PixverseTextToVideoNode": "PixVerse Text to Video",
|
||||
"PixverseImageToVideoNode": "PixVerse Image to Video",
|
||||
"PixverseTransitionVideoNode": "PixVerse Transition Video",
|
||||
"PixverseTemplateNode": "PixVerse Template",
|
||||
}
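The 1080p / duration / motion-mode constraint block is repeated verbatim in all three generation nodes above. As a refactoring sketch (not part of the diff), it could be factored into a single helper:

def clamp_pixverse_options(quality, duration_seconds, motion_mode):
    # 1080p renders are capped at 5 seconds and only support normal motion;
    # non-5-second durations likewise only support normal motion.
    if quality == PixverseQuality.res_1080p:
        return PixverseDuration.dur_5, PixverseMotionMode.normal
    if duration_seconds != PixverseDuration.dur_5:
        return duration_seconds, PixverseMotionMode.normal
    return duration_seconds, motion_mode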
|
||||
1138
comfy_api_nodes/nodes_recraft.py
Normal file
File diff suppressed because it is too large
614
comfy_api_nodes/nodes_stability.py
Normal file
@@ -0,0 +1,614 @@
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
StabilityResultsGetResponse,
|
||||
StabilityStable3_5Request,
|
||||
StabilityStableUltraRequest,
|
||||
StabilityStableUltraResponse,
|
||||
StabilityAspectRatio,
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StabilityPollStatus(str, Enum):
|
||||
finished = "finished"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
if x.name is not None or x.errors is not None:
|
||||
return StabilityPollStatus.failed
|
||||
elif x.finish_reason is not None:
|
||||
return StabilityPollStatus.finished
|
||||
return StabilityPollStatus.in_progress
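# Mapping: a populated name/errors field means the job failed, a finish_reason means the
# result is ready, and anything else is treated as still in progress.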
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode:
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue."
|
||||
},
|
||||
),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStableUltraRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node:
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"model": ([x.value for x in Stability_SD3_5_Model],),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"cfg_scale": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 4.0,
|
||||
"min": 1.0,
|
||||
"max": 10.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||
aspect_ratio = None
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStable3_5Request,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
cfg_scale=cfg_scale,
|
||||
model=model,
|
||||
mode=mode,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode:
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.35,
|
||||
"min": 0.2,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleConservativeRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode:
|
||||
"""
|
||||
Upscale image with creative reinterpretation of details to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.3,
|
||||
"min": 0.1,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleCreativeRequest,
|
||||
response_model=StabilityAsyncResponse,
|
||||
),
|
||||
request=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
style_preset=style_preset,
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityResultsGetResponse,
|
||||
),
|
||||
poll_interval=3,
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode:
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor,
|
||||
**kwargs):
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
|
||||
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
|
||||
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
|
||||
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
|
||||
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
|
||||
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
|
||||
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
|
||||
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
|
||||
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
|
||||
}
|
||||
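Note the difference between the three upscalers above: Conservative and Fast return the result image synchronously, while Creative submits a job and then polls /proxy/stability/v2beta/results/{id} until it finishes. The following is an illustrative sketch only (not part of the diff) of calling the conservative node directly, assuming `img` is a ComfyUI IMAGE tensor of shape (B, H, W, C) and that the credential keyword arguments stand in for the hidden inputs the UI would normally supply; the token value is a placeholder.

# Hypothetical direct invocation, outside the ComfyUI graph:
node = StabilityUpscaleConservativeNode()
(upscaled,) = node.api_call(
    image=img,
    prompt="a sharp, detailed photograph",
    creativity=0.35,          # rounded to two decimals before the request
    seed=42,
    negative_prompt=None,
    auth_token="<placeholder>",  # normally injected via the hidden inputs
)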
308 comfy_api_nodes/nodes_veo2.py Normal file
@@ -0,0 +1,308 @@
import io
import logging
import base64
import requests
import torch
from typing import Optional

from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
    Veo2GenVidRequest,
    Veo2GenVidResponse,
    Veo2GenVidPollRequest,
    Veo2GenVidPollResponse
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
)

from comfy_api_nodes.apinode_utils import (
    downscale_image_tensor,
    tensor_to_base64_string
)

AVERAGE_DURATION_VIDEO_GEN = 32

def convert_image_to_base64(image: torch.Tensor):
    if image is None:
        return None

    scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
    return tensor_to_base64_string(scaled_image)


def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
    if (
        poll_response.response
        and hasattr(poll_response.response, "videos")
        and poll_response.response.videos
        and len(poll_response.response.videos) > 0
    ):
        video = poll_response.response.videos[0]
    else:
        return None
    if hasattr(video, "gcsUri") and video.gcsUri:
        return str(video.gcsUri)
    return None


class VeoVideoGenerationNode(ComfyNodeABC):
    """
    Generates videos from text prompts using Google's Veo API.

    This node can create videos from text descriptions and optional image inputs,
    with control over parameters like aspect ratio, duration, and more.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text description of the video",
                    },
                ),
                "aspect_ratio": (
                    IO.COMBO,
                    {
                        "options": ["16:9", "9:16"],
                        "default": "16:9",
                        "tooltip": "Aspect ratio of the output video",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Negative text prompt to guide what to avoid in the video",
                    },
                ),
                "duration_seconds": (
                    IO.INT,
                    {
                        "default": 5,
                        "min": 5,
                        "max": 8,
                        "step": 1,
                        "display": "number",
                        "tooltip": "Duration of the output video in seconds",
                    },
                ),
                "enhance_prompt": (
                    IO.BOOLEAN,
                    {
                        "default": True,
                        "tooltip": "Whether to enhance the prompt with AI assistance",
                    }
                ),
                "person_generation": (
                    IO.COMBO,
                    {
                        "options": ["ALLOW", "BLOCK"],
                        "default": "ALLOW",
                        "tooltip": "Whether to allow generating people in the video",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFF,
                        "step": 1,
                        "display": "number",
                        "control_after_generate": True,
                        "tooltip": "Seed for video generation (0 for random)",
                    },
                ),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image to guide video generation",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = (IO.VIDEO,)
    FUNCTION = "generate_video"
    CATEGORY = "api node/video/Veo"
    DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
    API_NODE = True

    def generate_video(
        self,
        prompt,
        aspect_ratio="16:9",
        negative_prompt="",
        duration_seconds=5,
        enhance_prompt=True,
        person_generation="ALLOW",
        seed=0,
        image=None,
        unique_id: Optional[str] = None,
        **kwargs,
    ):
        # Prepare the instances for the request
        instances = []

        instance = {
            "prompt": prompt
        }

        # Add image if provided
        if image is not None:
            image_base64 = convert_image_to_base64(image)
            if image_base64:
                instance["image"] = {
                    "bytesBase64Encoded": image_base64,
                    "mimeType": "image/png"
                }

        instances.append(instance)

        # Create parameters dictionary
        parameters = {
            "aspectRatio": aspect_ratio,
            "personGeneration": person_generation,
            "durationSeconds": duration_seconds,
            "enhancePrompt": enhance_prompt,
        }

        # Add optional parameters if provided
        if negative_prompt:
            parameters["negativePrompt"] = negative_prompt
        if seed > 0:
            parameters["seed"] = seed

        # Initial request to start video generation
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/veo/generate",
                method=HttpMethod.POST,
                request_model=Veo2GenVidRequest,
                response_model=Veo2GenVidResponse
            ),
            request=Veo2GenVidRequest(
                instances=instances,
                parameters=parameters
            ),
            auth_kwargs=kwargs,
        )

        initial_response = initial_operation.execute()
        operation_name = initial_response.name

        logging.info(f"Veo generation started with operation name: {operation_name}")

        # Define status extractor function
        def status_extractor(response):
            # Only return "completed" if the operation is done, regardless of success or failure
            # We'll check for errors after polling completes
            return "completed" if response.done else "pending"

        # Define progress extractor function
        def progress_extractor(response):
            # Could be enhanced if the API provides progress information
            return None

        # Define the polling operation
        poll_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path="/proxy/veo/poll",
                method=HttpMethod.POST,
                request_model=Veo2GenVidPollRequest,
                response_model=Veo2GenVidPollResponse
            ),
            completed_statuses=["completed"],
            failed_statuses=[],  # No failed statuses, we'll handle errors after polling
            status_extractor=status_extractor,
            progress_extractor=progress_extractor,
            request=Veo2GenVidPollRequest(
                operationName=operation_name
            ),
            auth_kwargs=kwargs,
            poll_interval=5.0,
            result_url_extractor=get_video_url_from_response,
            node_id=unique_id,
            estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
        )

        # Execute the polling operation
        poll_response = poll_operation.execute()

        # Now check for errors in the final response
        # Check for error in poll response
        if hasattr(poll_response, 'error') and poll_response.error:
            error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
            logging.error(error_message)
            raise Exception(error_message)

        # Check for RAI filtered content
        if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
                poll_response.response.raiMediaFilteredCount > 0):

            # Extract reason message if available
            if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
                    poll_response.response.raiMediaFilteredReasons):
                reason = poll_response.response.raiMediaFilteredReasons[0]
                error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
            else:
                error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"

            logging.error(error_message)
            raise Exception(error_message)

        # Extract video data
        video_data = None
        if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
            video = poll_response.response.videos[0]

            # Check if video is provided as base64 or URL
            if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
                # Decode base64 string to bytes
                video_data = base64.b64decode(video.bytesBase64Encoded)
            elif hasattr(video, 'gcsUri') and video.gcsUri:
                # Download from URL
                video_url = video.gcsUri
                video_response = requests.get(video_url)
                video_data = video_response.content
            else:
                raise Exception("Video returned but no data or URL was provided")
        else:
            raise Exception("Video generation completed but no video was returned")

        if not video_data:
            raise Exception("No video data was returned")

        logging.info("Video generation completed successfully")

        # Convert video data to BytesIO object
        video_io = io.BytesIO(video_data)

        # Return VideoFromFile object
        return (VideoFromFile(video_io),)


# Register the node
NODE_CLASS_MAPPINGS = {
    "VeoVideoGenerationNode": VeoVideoGenerationNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VeoVideoGenerationNode": "Google Veo2 Video Generation",
}
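The node above assembles a REST-style body of `instances` and `parameters` before posting to /proxy/veo/generate. The following sketch (not part of the diff; all values are illustrative) shows the shape of that payload for an image-conditioned request, mirroring the keys set in generate_video():

# Illustrative payload built by generate_video() for one request:
payload = {
    "instances": [
        {
            "prompt": "a drone shot over a coastline at sunset",
            "image": {  # only present when a reference image is supplied
                "bytesBase64Encoded": "<base64 PNG data>",
                "mimeType": "image/png",
            },
        }
    ],
    "parameters": {
        "aspectRatio": "16:9",
        "personGeneration": "ALLOW",
        "durationSeconds": 5,
        "enhancePrompt": True,
        "negativePrompt": "low quality",  # added only when non-empty
        "seed": 1234,                     # added only when seed > 0
    },
}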
10 comfy_api_nodes/redocly-dev.yaml Normal file
@@ -0,0 +1,10 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
# This is used for development purposes to generate stubs for unreleased API endpoints.
apis:
  filter:
    root: openapi.yaml
    decorators:
      filter-in:
        property: tags
        value: ['API Nodes']
        matchStrategy: all
10 comfy_api_nodes/redocly.yaml Normal file
@@ -0,0 +1,10 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.

apis:
  filter:
    root: openapi.yaml
    decorators:
      filter-in:
        property: tags
        value: ['API Nodes', 'Released']
        matchStrategy: all
0 comfy_api_nodes/util/__init__.py Normal file
100 comfy_api_nodes/util/validation_utils.py Normal file
@@ -0,0 +1,100 @@
import logging
from typing import Optional

import torch
from comfy_api.input.video_types import VideoInput


def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    # Returns (height, width) for ComfyUI IMAGE tensors of shape (B, H, W, C) or (H, W, C).
    if len(image.shape) == 4:
        return image.shape[1], image.shape[2]
    elif len(image.shape) == 3:
        return image.shape[0], image.shape[1]
    else:
        raise ValueError("Invalid image tensor shape.")


def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    height, width = get_image_dimensions(image)

    if min_width is not None and width < min_width:
        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Image height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")


def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    # get_image_dimensions returns (height, width), so unpack in that order
    # before computing width / height.
    height, width = get_image_dimensions(image)
    aspect_ratio = width / height

    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )


def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        logging.error("Error getting dimensions of video: %s", e)
        return

    if min_width is not None and width < min_width:
        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Video height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")


def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    try:
        duration = video.get_duration()
    except Exception as e:
        logging.error("Error getting duration of video: %s", e)
        return

    epsilon = 0.0001
    if min_duration is not None and min_duration - epsilon > duration:
        raise ValueError(
            f"Video duration must be at least {min_duration}s, got {duration}s"
        )
    if max_duration is not None and duration > max_duration + epsilon:
        raise ValueError(
            f"Video duration must be at most {max_duration}s, got {duration}s"
        )
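A minimal usage sketch of these helpers (the limits are hypothetical, chosen only for illustration) shows how an API node could guard its inputs before uploading an image; each validator raises ValueError when a bound is violated and otherwise returns None:

import torch
from comfy_api_nodes.util.validation_utils import (
    validate_image_dimensions,
    validate_image_aspect_ratio,
)

image = torch.zeros([1, 512, 768, 3])  # ComfyUI IMAGE layout: (B, H, W, C)
validate_image_dimensions(image, min_width=64, max_width=4096, min_height=64, max_height=4096)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)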
49 comfy_extras/nodes_ace.py Normal file
@@ -0,0 +1,49 @@
import torch
import comfy.model_management
import node_helpers

class TextEncodeAceStepAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip, tags, lyrics, lyrics_strength):
        tokens = clip.tokenize(tags, lyrics=lyrics)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
        return (conditioning, )


class EmptyAceStepLatentAudio:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/audio"

    def generate(self, seconds, batch_size):
        length = int(seconds * 44100 / 512 / 8)
        latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
        return ({"samples": latent, "type": "audio"}, )


NODE_CLASS_MAPPINGS = {
    "TextEncodeAceStepAudio": TextEncodeAceStepAudio,
    "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}
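Worked example for the latent length formula in EmptyAceStepLatentAudio (the interpretation of the constants is an assumption; the code does not document them, but they read as 44.1 kHz audio divided by a 512-sample hop and a further 8x temporal compression):

#   120 s default input:
#   120 * 44100 = 5_292_000 samples
#   5_292_000 / 512 / 8 ≈ 1291.99  ->  int(...) = 1291
# so generate(120.0, 1) returns a latent of shape [1, 8, 16, 1291].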
76 comfy_extras/nodes_apg.py Normal file
@@ -0,0 +1,76 @@
import torch

def project(v0, v1):
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel, v0_orthogonal

class APG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disabled at a setting of 0."}),
                "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip": "Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "sampling/custom_sampling"

    def patch(self, model, eta, norm_threshold, momentum):
        running_avg = 0
        prev_sigma = None

        def pre_cfg_function(args):
            nonlocal running_avg, prev_sigma

            if len(args["conds_out"]) == 1: return args["conds_out"]

            cond = args["conds_out"][0]
            uncond = args["conds_out"][1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            if prev_sigma is not None and sigma > prev_sigma:
                running_avg = 0
            prev_sigma = sigma

            guidance = cond - uncond

            if momentum != 0:
                if not torch.is_tensor(running_avg):
                    running_avg = guidance
                else:
                    running_avg = momentum * running_avg + guidance
                guidance = running_avg

            if norm_threshold > 0:
                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale = torch.minimum(
                    torch.ones_like(guidance_norm),
                    norm_threshold / guidance_norm
                )
                guidance = guidance * scale

            guidance_parallel, guidance_orthogonal = project(guidance, cond)
            modified_guidance = guidance_orthogonal + eta * guidance_parallel

            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale

            return [modified_cond, uncond] + args["conds_out"][2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m,)

NODE_CLASS_MAPPINGS = {
    "APG": APG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "APG": "Adaptive Projected Guidance",
}
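project() decomposes the guidance vector into a component parallel to the (normalized) conditional prediction and an orthogonal remainder; APG then rescales only the parallel component by eta before rebuilding the guided prediction. A small self-contained check of that decomposition (illustrative only; it assumes the project() function defined in the file above is importable):

import torch

v0 = torch.randn(1, 4, 8, 8)   # stands in for the guidance tensor (cond - uncond)
v1 = torch.randn(1, 4, 8, 8)   # stands in for the conditional prediction

par, orth = project(v0, v1)
# The two parts sum back to the original vector exactly...
assert torch.allclose(par + orth, v0, atol=1e-5)
# ...and the orthogonal part has essentially no component along v1.
v1_unit = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
assert abs((orth * v1_unit).sum().item()) < 1e-3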
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import av
|
||||
import torchaudio
|
||||
import torch
|
||||
import comfy.model_management
|
||||
@@ -7,7 +8,6 @@ import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
@@ -90,60 +90,118 @@ class VAEDecodeAudio:
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
|
||||
def create_vorbis_comment_block(comment_dict, last_block):
|
||||
vendor_string = b'ComfyUI'
|
||||
vendor_length = len(vendor_string)
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
|
||||
comments = []
|
||||
for key, value in comment_dict.items():
|
||||
comment = f"{key}={value}".encode('utf-8')
|
||||
comments.append(struct.pack('<I', len(comment)) + comment)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
|
||||
user_comment_list_length = len(comments)
|
||||
user_comments = b''.join(comments)
|
||||
# Prepare metadata dictionary
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
|
||||
if last_block:
|
||||
id = b'\x84'
|
||||
else:
|
||||
id = b'\x04'
|
||||
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
|
||||
# Opus supported sample rates
|
||||
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
return comment_block
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
||||
if len(comment_dict) == 0:
|
||||
return flac_io
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
flac_io.seek(4)
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
blocks = []
|
||||
last_block = False
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
while not last_block:
|
||||
header = flac_io.read(4)
|
||||
last_block = (header[0] & 0x80) != 0
|
||||
block_type = header[0] & 0x7F
|
||||
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
|
||||
block_data = flac_io.read(block_length)
|
||||
# Create in-memory WAV buffer
|
||||
wav_buffer = io.BytesIO()
|
||||
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
|
||||
wav_buffer.seek(0) # Rewind for reading
|
||||
|
||||
if block_type == 4 or block_type == 1:
|
||||
pass
|
||||
else:
|
||||
header = bytes([(header[0] & (~0x80))]) + header[1:]
|
||||
blocks.append(header + block_data)
|
||||
# Use PyAV to convert and add metadata
|
||||
input_container = av.open(wav_buffer)
|
||||
|
||||
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
|
||||
new_flac_io = io.BytesIO()
|
||||
new_flac_io.write(b'fLaC')
|
||||
for block in blocks:
|
||||
new_flac_io.write(block)
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
new_flac_io.write(flac_io.read())
|
||||
return new_flac_io
|
||||
# Set up the output stream with appropriate properties
|
||||
input_container.streams.audio[0]
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
|
||||
# Copy frames from input to output
|
||||
for frame in input_container.decode(audio=0):
|
||||
frame.pts = None # Let PyAV handle timestamps
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@@ -153,50 +211,70 @@ class SaveAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_audio"
|
||||
FUNCTION = "save_flac"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
buff = io.BytesIO()
|
||||
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_mp3"
|
||||
|
||||
buff = insert_or_replace_vorbis_comment(buff, metadata)
|
||||
OUTPUT_NODE = True
|
||||
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as f:
|
||||
f.write(buff.getbuffer())
|
||||
CATEGORY = "audio"
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_opus"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
@@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
"VAEDecodeAudio": VAEDecodeAudio,
|
||||
"SaveAudio": SaveAudio,
|
||||
"SaveAudioMP3": SaveAudioMP3,
|
||||
"SaveAudioOpus": SaveAudioOpus,
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentAudio": "Empty Latent Audio",
|
||||
"VAEEncodeAudio": "VAE Encode Audio",
|
||||
"VAEDecodeAudio": "VAE Decode Audio",
|
||||
"PreviewAudio": "Preview Audio",
|
||||
"LoadAudio": "Load Audio",
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
}
|
||||
|
||||
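One detail worth calling out in the save_audio() hunk above: Opus only encodes at 8, 12, 16, 24, or 48 kHz, so the code clamps anything above 48 kHz, otherwise picks the next supported rate above the input (for example 44100 -> 48000), and resamples with torchaudio before handing the waveform to PyAV. A condensed sketch of that selection logic (restated from the diff, not a separate implementation):

OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

def pick_opus_rate(sample_rate: int) -> int:
    if sample_rate > 48000:
        return 48000
    if sample_rate in OPUS_RATES:
        return sample_rate
    for rate in OPUS_RATES:  # ascending, so this is the next highest supported rate
        if rate > sample_rate:
            return rate
    return 48000  # fallback

assert pick_opus_rate(44100) == 48000
assert pick_opus_rate(22050) == 24000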
218 comfy_extras/nodes_camera_trajectory.py Normal file
@@ -0,0 +1,218 @@
import nodes
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
CAMERA_DICT = {
|
||||
"base_T_norm": 1.5,
|
||||
"base_angle": np.pi/3,
|
||||
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
|
||||
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
|
||||
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
|
||||
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
|
||||
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
|
||||
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
|
||||
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
|
||||
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
|
||||
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
|
||||
}
|
||||
|
||||
|
||||
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
||||
|
||||
def get_relative_pose(cam_params):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
||||
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
||||
cam_to_origin = 0
|
||||
target_cam_c2w = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, -cam_to_origin],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
||||
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
||||
ret_poses = np.array(ret_poses, dtype=np.float32)
|
||||
return ret_poses
|
||||
|
||||
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
||||
|
||||
sample_wh_ratio = width / height
|
||||
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
||||
|
||||
if pose_wh_ratio > sample_wh_ratio:
|
||||
resized_ori_w = height * pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fx = resized_ori_w * cam_param.fx / width
|
||||
else:
|
||||
resized_ori_h = width / pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fy = resized_ori_h * cam_param.fy / height
|
||||
|
||||
intrinsic = np.asarray([[cam_param.fx * width,
|
||||
cam_param.fy * height,
|
||||
cam_param.cx * width,
|
||||
cam_param.cy * height]
|
||||
for cam_param in cam_params], dtype=np.float32)
|
||||
|
||||
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
||||
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
||||
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
||||
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
return plucker_embedding
|
||||
|
||||
class Camera(object):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
def __init__(self, entry):
|
||||
fx, fy, cx, cy = entry[1:5]
|
||||
self.fx = fx
|
||||
self.fy = fy
|
||||
self.cx = cx
|
||||
self.cy = cy
|
||||
c2w_mat = np.array(entry[7:]).reshape(4, 4)
|
||||
self.c2w_mat = c2w_mat
|
||||
self.w2c_mat = np.linalg.inv(c2w_mat)
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
# c2w: B, V, 4, 4
|
||||
# K: B, V, 4
|
||||
|
||||
B = K.shape[0]
|
||||
|
||||
j, i = torch.meshgrid(
|
||||
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
||||
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
||||
indexing='ij'
|
||||
)
|
||||
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
|
||||
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
||||
|
||||
zs = torch.ones_like(i) # [B, HxW]
|
||||
xs = (i - cx) / fx * zs
|
||||
ys = (j - cy) / fy * zs
|
||||
zs = zs.expand_as(ys)
|
||||
|
||||
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
||||
|
||||
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ directions
|
||||
rays_dxo = torch.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
def get_camera_motion(angle, T, speed, n=81):
|
||||
def compute_R_form_rad_angle(angles):
|
||||
theta_x, theta_y, theta_z = angles
|
||||
Rx = np.array([[1, 0, 0],
|
||||
[0, np.cos(theta_x), -np.sin(theta_x)],
|
||||
[0, np.sin(theta_x), np.cos(theta_x)]])
|
||||
|
||||
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
|
||||
[0, 1, 0],
|
||||
[-np.sin(theta_y), 0, np.cos(theta_y)]])
|
||||
|
||||
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
|
||||
[np.sin(theta_z), np.cos(theta_z), 0],
|
||||
[0, 0, 1]])
|
||||
|
||||
R = np.dot(Rz, np.dot(Ry, Rx))
|
||||
return R
|
||||
RT = []
|
||||
for i in range(n):
|
||||
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
|
||||
R = compute_R_form_rad_angle(_angle)
|
||||
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
|
||||
_RT = np.concatenate([R,_T], axis=1)
|
||||
RT.append(_RT)
|
||||
RT = np.stack(RT)
|
||||
return RT
|
||||
|
||||
class WanCameraEmbedding:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
|
||||
},
|
||||
"optional":{
|
||||
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
|
||||
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
||||
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
||||
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
||||
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
|
||||
RETURN_NAMES = ("camera_embedding","width","height","length")
|
||||
FUNCTION = "run"
|
||||
CATEGORY = "camera"
|
||||
|
||||
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
|
||||
"""
|
||||
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
|
||||
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
||||
"""
|
||||
motion_list = [camera_pose]
|
||||
speed = speed
|
||||
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
|
||||
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
|
||||
RT = get_camera_motion(angle, T, speed, length)
|
||||
|
||||
trajs=[]
|
||||
for cp in RT.tolist():
|
||||
traj=[fx,fy,cx,cy,0,0]
|
||||
traj.extend(cp[0])
|
||||
traj.extend(cp[1])
|
||||
traj.extend(cp[2])
|
||||
traj.extend([0,0,0,1])
|
||||
trajs.append(traj)
|
||||
|
||||
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
|
||||
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
|
||||
control_camera_video = process_pose_params(cam_params, width=width, height=height)
|
||||
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
|
||||
|
||||
control_camera_video = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
|
||||
# Reshape, transpose, and view into desired shape
|
||||
b, f, c, h, w = control_camera_video.shape
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
|
||||
return (control_camera_video, width, height, length)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanCameraEmbedding": WanCameraEmbedding,
|
||||
}
|
||||
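For reference, the quantity ray_condition() builds is the per-pixel Plücker embedding of each camera ray: with ray origin o and unit direction d(u, v) derived from the intrinsics and relative pose, the embedding is the 6-vector (o × d, d), which is exactly the rays_dxo / rays_d concatenation in the code above; WanCameraEmbedding then packs these per-frame maps into the control tensor the model consumes.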
@@ -31,6 +31,7 @@ class T5TokenizerOptions:
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "_for_testing/conditioning"
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "set_options"
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
|
||||
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
elif guidance_type == "v2 (replace)":
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
elif guidance_type == "custom":
|
||||
cond = {"ref_latent": concat_latent_image}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@@ -71,6 +75,24 @@ class ImageFromBatch:
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
|
||||
|
||||
class ImageAddNoise:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def repeat(self, image, seed, strength):
|
||||
generator = torch.manual_seed(seed)
|
||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||
return (s,)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@@ -190,10 +212,110 @@ class SaveAnimatedPNG:
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "image/save" # Changed
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("SVG",), # Changed
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
|
||||
# Read SVG content
|
||||
svg_bytes.seek(0)
|
||||
svg_content = svg_bytes.read().decode('utf-8')
|
||||
|
||||
# Inject metadata if available
|
||||
if metadata_json:
|
||||
# Create metadata element with CDATA section
|
||||
metadata_element = f""" <metadata>
|
||||
<![CDATA[
|
||||
{metadata_json}
|
||||
]]>
|
||||
</metadata>
|
||||
"""
|
||||
# Insert metadata after opening svg tag using regex with a replacement function
|
||||
def replacement(match):
|
||||
# match.group(1) contains the captured <svg> tag
|
||||
return match.group(1) + '\n' + metadata_element
|
||||
|
||||
# Apply the substitution
|
||||
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
|
||||
|
||||
# Write the modified SVG to file
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"ImageAddNoise": ImageAddNoise,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
}
|
||||
|
||||
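SaveSVGNode injects the workflow metadata as a <metadata> CDATA block immediately after the opening <svg> tag, using re.sub with a replacement function so the rest of the document is untouched. A standalone sketch of the same transformation on a toy SVG (illustrative only; the input string is made up):

import re
import json

svg_content = '<svg xmlns="http://www.w3.org/2000/svg" width="10" height="10"><rect/></svg>'
metadata_json = json.dumps({"prompt": {"example": True}}, indent=2)
metadata_element = f" <metadata>\n<![CDATA[\n{metadata_json}\n]]>\n</metadata>\n"

def replacement(match):
    # match.group(1) is the captured opening <svg ...> tag
    return match.group(1) + '\n' + metadata_element

svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
print(svg_content)  # the metadata block now sits directly after the opening <svg> tag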
@@ -2,6 +2,10 @@ import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
|
||||
@@ -21,8 +25,8 @@ class Load3D():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -41,7 +45,14 @@ class Load3D():
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@@ -59,8 +70,8 @@ class Load3DAnimation():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -77,7 +88,14 @@ class Load3DAnimation():
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
|
||||
@@ -21,6 +21,21 @@ class String(ComfyNodeABC):
|
||||
return (value,)
|
||||
|
||||
|
||||
class StringMultiline(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {"value": (IO.STRING, {"multiline": True,},)},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/primitive"
|
||||
|
||||
def execute(self, value: str) -> tuple[str]:
|
||||
return (value,)
|
||||
|
||||
|
||||
class Int(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
@@ -68,6 +83,7 @@ class Boolean(ComfyNodeABC):
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PrimitiveString": String,
|
||||
"PrimitiveStringMultiline": StringMultiline,
|
||||
"PrimitiveInt": Int,
|
||||
"PrimitiveFloat": Float,
|
||||
"PrimitiveBoolean": Boolean,
|
||||
@@ -75,6 +91,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PrimitiveString": "String",
|
||||
"PrimitiveStringMultiline": "String (Multiline)",
|
||||
"PrimitiveInt": "Int",
|
||||
"PrimitiveFloat": "Float",
|
||||
"PrimitiveBoolean": "Boolean",
|
||||
|
||||
323 comfy_extras/nodes_string.py Normal file
@@ -0,0 +1,323 @@
import re
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
|
||||
class StringConcatenate():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, delimiter, **kwargs):
|
||||
return delimiter.join((string_a, string_b)),
|
||||
|
||||
class StringSubstring():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"start": (IO.INT, {}),
|
||||
"end": (IO.INT, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, start, end, **kwargs):
|
||||
return string[start:end],
|
||||
|
||||
class StringLength():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("length",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, **kwargs):
|
||||
length = len(string)
|
||||
|
||||
return length,
|
||||
|
||||
class CaseConverter():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
if mode == "UPPERCASE":
|
||||
result = string.upper()
|
||||
elif mode == "lowercase":
|
||||
result = string.lower()
|
||||
elif mode == "Capitalize":
|
||||
result = string.capitalize()
|
||||
elif mode == "Title Case":
|
||||
result = string.title()
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class StringTrim():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
if mode == "Both":
|
||||
result = string.strip()
|
||||
elif mode == "Left":
|
||||
result = string.lstrip()
|
||||
elif mode == "Right":
|
||||
result = string.rstrip()
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
|
||||
class StringReplace():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"find": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, find, replace, **kwargs):
|
||||
result = string.replace(find, replace)
|
||||
return result,
|
||||
|
||||
|
||||
class StringContains():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"substring": (IO.STRING, {"multiline": True}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("contains",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, substring, case_sensitive, **kwargs):
|
||||
if case_sensitive:
|
||||
contains = substring in string
|
||||
else:
|
||||
contains = substring.lower() in string.lower()
|
||||
|
||||
return contains,
|
||||
|
||||
|
||||
class StringCompare():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
|
||||
if case_sensitive:
|
||||
a = string_a
|
||||
b = string_b
|
||||
else:
|
||||
a = string_a.lower()
|
||||
b = string_b.lower()
|
||||
|
||||
if mode == "Equal":
|
||||
return a == b,
|
||||
elif mode == "Starts With":
|
||||
return a.startswith(b),
|
||||
elif mode == "Ends With":
|
||||
return a.endswith(b),
|
||||
|
||||
class RegexMatch():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("matches",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
|
||||
try:
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
result = match is not None
|
||||
|
||||
except re.error:
|
||||
result = False
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class RegexExtract():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False}),
|
||||
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
|
||||
join_delimiter = "\n"
|
||||
|
||||
flags = 0
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
|
||||
try:
|
||||
if mode == "First Match":
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
if match:
|
||||
result = match.group(0)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "All Matches":
|
||||
matches = re.findall(regex_pattern, string, flags)
|
||||
if matches:
|
||||
if isinstance(matches[0], tuple):
|
||||
result = join_delimiter.join([m[0] for m in matches])
|
||||
else:
|
||||
result = join_delimiter.join(matches)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "First Group":
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
if match and len(match.groups()) >= group_index:
|
||||
result = match.group(group_index)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "All Groups":
|
||||
matches = re.finditer(regex_pattern, string, flags)
|
||||
results = []
|
||||
for match in matches:
|
||||
if match.groups() and len(match.groups()) >= group_index:
|
||||
results.append(match.group(group_index))
|
||||
result = join_delimiter.join(results)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
except re.error:
|
||||
result = ""
|
||||
|
||||
return result,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
"StringLength": StringLength,
|
||||
"CaseConverter": CaseConverter,
|
||||
"StringTrim": StringTrim,
|
||||
"StringReplace": StringReplace,
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringConcatenate": "Concatenate",
|
||||
"StringSubstring": "Substring",
|
||||
"StringLength": "Length",
|
||||
"CaseConverter": "Case Converter",
|
||||
"StringTrim": "Trim",
|
||||
"StringReplace": "Replace",
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract"
|
||||
}
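The classes above follow the standard ComfyUI node shape (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY) and hold no state, so they can be exercised directly for a quick sanity check outside the graph executor. A minimal sketch, assuming a ComfyUI checkout is on the Python path so comfy_extras.nodes_string is importable:

# Sketch only: calls the new string nodes directly; each execute() returns a 1-tuple.
from comfy_extras.nodes_string import StringConcatenate, CaseConverter, RegexExtract

joined, = StringConcatenate().execute("foo", "bar", delimiter=", ")       # "foo, bar"
upper, = CaseConverter().execute("hello world", mode="UPPERCASE")         # "HELLO WORLD"
# "First Group" returns the capture group selected by group_index, or "" if the pattern does not match.
year, = RegexExtract().execute("released 2024-05-01", r"(\d{4})-(\d{2})",
                               mode="First Group", case_insensitive=False,
                               multiline=False, dotall=False, group_index=1)  # "2024"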
|
||||
@@ -297,6 +297,52 @@ class TrimVideoLatent:
|
||||
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||
return (samples_out,)
|
||||
|
||||
class WanCameraImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
||||
|
||||
if camera_conditions is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
|
||||
negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
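For the default inputs above (length=81, height=480, width=832, batch_size=1), the zero-initialized latent is sized by the same arithmetic used in encode(): the temporal axis is compressed 4x (plus the initial frame) and the spatial axes 8x. A small sketch of that calculation, for reference:

# Sketch: latent shape produced for the default WanCameraImageToVideo inputs.
batch_size, length, height, width = 1, 81, 480, 832
latent_shape = (batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8)
print(latent_shape)  # (1, 16, 21, 60, 104): 81 pixel frames map to 21 latent frames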
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
@@ -305,4 +351,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||
"WanVaceToVideo": WanVaceToVideo,
|
||||
"TrimVideoLatent": TrimVideoLatent,
|
||||
"WanCameraImageToVideo": WanCameraImageToVideo,
|
||||
}
|
||||
|
||||
comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.31"
__version__ = "0.3.34"

execution.py
@@ -146,6 +146,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys

map_node_over_list = None #Don't hook this please
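The two new branches fill hidden inputs named AUTH_TOKEN_COMFY_ORG and API_KEY_COMFY_ORG from the request's extra_data, mirroring the existing hidden-input convention (UNIQUE_ID just above). A hedged sketch of how a node could opt in; the node itself is hypothetical, only the hidden-input keys come from this diff:

# Hypothetical node: declares a hidden input so get_input_data() injects
# extra_data["api_key_comfy_org"] into it at execution time.
class ExampleApiNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"prompt": ("STRING", {"multiline": True})},
            "hidden": {"api_key": "API_KEY_COMFY_ORG"},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "api node/example"

    def run(self, prompt, api_key=None):
        # api_key is None unless the client supplied extra_data["api_key_comfy_org"].
        return (prompt,)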
|
||||
|
||||
fix_torch.py (deleted, 28 lines)
@@ -1,28 +0,0 @@
|
||||
import importlib.util
|
||||
import shutil
|
||||
import os
|
||||
import ctypes
|
||||
import logging
|
||||
|
||||
|
||||
def fix_pytorch_libomp():
|
||||
"""
|
||||
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
|
||||
"""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
lib_folder = os.path.join(folder, "lib")
|
||||
test_file = os.path.join(lib_folder, "fbgemm.dll")
|
||||
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
|
||||
if os.path.exists(dest):
|
||||
break
|
||||
|
||||
with open(test_file, "rb") as f:
|
||||
contents = f.read()
|
||||
if b"libomp140.x86_64.dll" not in contents:
|
||||
break
|
||||
try:
|
||||
ctypes.cdll.LoadLibrary(test_file)
|
||||
except FileNotFoundError:
|
||||
logging.warning("Detected pytorch version with libomp issue, patching.")
|
||||
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
|
||||
main.py (9 changed lines)
@@ -125,13 +125,6 @@ if __name__ == "__main__":
|
||||
|
||||
import cuda_malloc
|
||||
|
||||
if args.windows_standalone_build:
|
||||
try:
|
||||
from fix_torch import fix_pytorch_libomp
|
||||
fix_pytorch_libomp()
|
||||
except:
|
||||
pass
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
@@ -270,7 +263,7 @@ def start_comfyui(asyncio_loop=None):
|
||||
q = execution.PromptQueue(prompt_server)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
nodes.py (56 changed lines)
@@ -246,6 +246,9 @@ class ConditioningZeroOut:
|
||||
pooled_output = d.get("pooled_output", None)
|
||||
if pooled_output is not None:
|
||||
d["pooled_output"] = torch.zeros_like(pooled_output)
|
||||
conditioning_lyrics = d.get("conditioning_lyrics", None)
|
||||
if conditioning_lyrics is not None:
|
||||
d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics)
|
||||
n = [torch.zeros_like(t[0]), d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
@@ -917,7 +920,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -1937,7 +1940,7 @@ class ImagePadForOutpaint:
|
||||
|
||||
mask[top:top + d2, left:left + d3] = t
|
||||
|
||||
return (new_image, mask)
|
||||
return (new_image, mask.unsqueeze(0))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -2258,12 +2261,11 @@ def init_builtin_extra_nodes():
|
||||
"nodes_optimalsteps.py",
|
||||
"nodes_hidream.py",
|
||||
"nodes_fresca.py",
|
||||
"nodes_apg.py",
|
||||
"nodes_preview_any.py",
|
||||
]
|
||||
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_api.py",
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2271,6 +2273,29 @@ def init_builtin_extra_nodes():
|
||||
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
||||
|
||||
def init_builtin_api_nodes():
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_ideogram.py",
|
||||
"nodes_openai.py",
|
||||
"nodes_minimax.py",
|
||||
"nodes_veo2.py",
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_pika.py",
|
||||
]
|
||||
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
return api_nodes_files
|
||||
|
||||
import_failed = []
|
||||
for node_file in api_nodes_files:
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
import_failed.append(node_file)
|
||||
@@ -2278,14 +2303,29 @@ def init_builtin_extra_nodes():
|
||||
return import_failed
|
||||
|
||||
|
||||
def init_extra_nodes(init_custom_nodes=True):
|
||||
def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
|
||||
import_failed = init_builtin_extra_nodes()
|
||||
|
||||
import_failed_api = []
|
||||
if init_api_nodes:
|
||||
import_failed_api = init_builtin_api_nodes()
|
||||
|
||||
if init_custom_nodes:
|
||||
init_external_custom_nodes()
|
||||
else:
|
||||
logging.info("Skipping loading of custom nodes")
|
||||
|
||||
if len(import_failed_api) > 0:
|
||||
logging.warning("WARNING: some comfy_api_nodes/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
|
||||
for node in import_failed_api:
|
||||
logging.warning("IMPORT FAILED: {}".format(node))
|
||||
logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.")
|
||||
if args.windows_standalone_build:
|
||||
logging.warning("Please run the update script: update/update_comfyui.bat")
|
||||
else:
|
||||
logging.warning("Please do a: pip install -r requirements.txt")
|
||||
logging.warning("")
|
||||
|
||||
if len(import_failed) > 0:
|
||||
logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
|
||||
for node in import_failed:
|
||||
|
||||
pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.31"
version = "0.3.34"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.18.6
comfyui-workflow-templates==0.1.3
comfyui-frontend-package==1.19.9
comfyui-workflow-templates==0.1.14
torch
torchsde
torchvision

script_examples/basic_api_example.py
@@ -101,6 +101,14 @@ prompt_text = """

def queue_prompt(prompt):
p = {"prompt": prompt}

# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data` field of the payload.
# p["extra_data"] = {
#     "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
# }
# See: https://docs.comfy.org/tutorials/api-nodes/overview
# Generate a key here: https://platform.comfy.org/login

data = json.dumps(p).encode('utf-8')
req = request.Request("http://127.0.0.1:8188/prompt", data=data)
request.urlopen(req)
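A compact variant of the helper that actually attaches the key, as the comments describe (the key value is a placeholder; json and request are the imports this example script already uses):

# Sketch: queue a prompt with a (placeholder) Comfy API key in extra_data.
def queue_prompt_with_api_key(prompt, api_key):
    p = {"prompt": prompt, "extra_data": {"api_key_comfy_org": api_key}}
    data = json.dumps(p).encode('utf-8')
    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
    request.urlopen(req)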
|
||||
|
||||
server.py (15 changed lines)
@@ -32,12 +32,13 @@ from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes

class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3

async def send_socket_catch_exception(function, message):
try:
@@ -878,3 +879,15 @@ class PromptServer():
logging.warning(traceback.format_exc())

return json_data

def send_progress_text(
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
):
if isinstance(text, str):
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")

# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text

self.send_sync(BinaryEventTypes.TEXT, message, sid)
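The payload built here is a 4-byte big-endian node-id length, the node-id bytes, then the UTF-8 text. A hedged client-side sketch that reverses just this packing (any event-type framing that send_sync adds around the message is not covered here):

import struct

def decode_progress_text(message: bytes) -> tuple[str, str]:
    # Reverse of: struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
    (node_id_len,) = struct.unpack_from(">I", message, 0)
    node_id = message[4:4 + node_id_len].decode("utf-8")
    text = message[4 + node_id_len:].decode("utf-8")
    return node_id, text

# Round-trips the packing shown above.
payload = struct.pack(">I", len(b"12")) + b"12" + "50%".encode("utf-8")
assert decode_progress_text(payload) == ("12", "50%")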
|
||||
|
||||
tests-unit/comfy_api_nodes_test/mapper_utils_test.py (new file, 297 lines)
@@ -0,0 +1,297 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
|
||||
|
||||
def test_model_field_to_float_input():
|
||||
"""Tests mapping a float field with constraints."""
|
||||
|
||||
class ModelWithFloatField(BaseModel):
|
||||
cfg_scale: Optional[float] = Field(
|
||||
default=0.5,
|
||||
description="Flexibility in video generation",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
multiple_of=0.001,
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"tooltip": "Flexibility in video generation",
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.001,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.FLOAT, ModelWithFloatField, "cfg_scale"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_float_input_no_constraints():
|
||||
"""Tests mapping a float field with no constraints."""
|
||||
|
||||
class ModelWithFloatField(BaseModel):
|
||||
cfg_scale: Optional[float] = Field(default=0.5)
|
||||
|
||||
expected_output = (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.FLOAT, ModelWithFloatField, "cfg_scale"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_int_input():
|
||||
"""Tests mapping an int field with constraints."""
|
||||
|
||||
class ModelWithIntField(BaseModel):
|
||||
num_frames: Optional[int] = Field(
|
||||
default=10,
|
||||
description="Number of frames to generate",
|
||||
ge=1,
|
||||
le=100,
|
||||
multiple_of=1,
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 10,
|
||||
"tooltip": "Number of frames to generate",
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"step": 1,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_string_input():
|
||||
"""Tests mapping a string field."""
|
||||
|
||||
class ModelWithStringField(BaseModel):
|
||||
prompt: Optional[str] = Field(
|
||||
default="A beautiful sunset over a calm ocean",
|
||||
description="A prompt for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "A beautiful sunset over a calm ocean",
|
||||
"tooltip": "A prompt for the video generation",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_string_input_multiline():
|
||||
"""Tests mapping a string field."""
|
||||
|
||||
class ModelWithStringField(BaseModel):
|
||||
prompt: Optional[str] = Field(
|
||||
default="A beautiful sunset over a calm ocean",
|
||||
description="A prompt for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "A beautiful sunset over a calm ocean",
|
||||
"tooltip": "A prompt for the video generation",
|
||||
"multiline": True,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithStringField, "prompt", multiline=True
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_combo_input():
|
||||
"""Tests mapping a combo field."""
|
||||
|
||||
class MockEnum(str, Enum):
|
||||
option_1 = "option 1"
|
||||
option_2 = "option 2"
|
||||
option_3 = "option 3"
|
||||
|
||||
class ModelWithComboField(BaseModel):
|
||||
model_name: Optional[MockEnum] = Field("option 1", description="Model Name")
|
||||
|
||||
expected_output = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["option 1", "option 2", "option 3"],
|
||||
"default": "option 1",
|
||||
"tooltip": "Model Name",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_combo_input_no_options():
|
||||
"""Tests mapping a combo field with no options."""
|
||||
|
||||
class ModelWithComboField(BaseModel):
|
||||
model_name: Optional[str] = Field(description="Model Name")
|
||||
|
||||
expected_output = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Model Name",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.COMBO, ModelWithComboField, "model_name"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_image_input():
|
||||
"""Tests mapping an image field."""
|
||||
|
||||
class ModelWithImageField(BaseModel):
|
||||
image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="An image for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "An image for the video generation",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_description():
|
||||
"""Tests mapping a field with no description."""
|
||||
|
||||
class ModelWithNoDescriptionField(BaseModel):
|
||||
field: Optional[str] = Field(default="default value")
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "default value",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoDescriptionField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_default():
|
||||
"""Tests mapping a field with no default."""
|
||||
|
||||
class ModelWithNoDefaultField(BaseModel):
|
||||
field: Optional[str] = Field(description="A field with no default")
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"tooltip": "A field with no default",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoDefaultField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_metadata():
|
||||
"""Tests mapping a field with no metadata or properties defined on the schema."""
|
||||
|
||||
class ModelWithNoMetadataField(BaseModel):
|
||||
field: Optional[str] = Field()
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoMetadataField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_default_is_none():
|
||||
"""
|
||||
Tests mapping a field with a default of `None`.
|
||||
I.e., the default field should be included as the schema explicitly sets it to `None`.
|
||||
"""
|
||||
|
||||
class ModelWithNoneDefaultField(BaseModel):
|
||||
field: Optional[str] = Field(
|
||||
default=None, description="A field with a default of None"
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "A field with a default of None",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoneDefaultField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
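All of these tests expect the (IO type, options dict) tuples that node INPUT_TYPES consume, so the mapper's output can be dropped straight into a node definition. A hedged sketch of that usage; the request model and node are hypothetical, while model_field_to_node_input and its keyword arguments (multiline, enum_type) are taken from the tests above:

from typing import Optional
from pydantic import BaseModel, Field
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.mapper_utils import model_field_to_node_input

class ExampleRequest(BaseModel):
    prompt: Optional[str] = Field(default="", description="Text prompt")
    cfg_scale: Optional[float] = Field(default=0.5, ge=0.0, le=1.0, multiple_of=0.001)

class ExampleApiNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": model_field_to_node_input(IO.STRING, ExampleRequest, "prompt", multiline=True),
                "cfg_scale": model_field_to_node_input(IO.FLOAT, ExampleRequest, "cfg_scale"),
            }
        }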
|
||||
tests-unit/comfy_api_test/video_types_test.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
import av
|
||||
import io
|
||||
from fractions import Fraction
|
||||
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.util.video_types import VideoComponents
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from av.error import InvalidDataError
|
||||
|
||||
EPSILON = 0.0001
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_images():
|
||||
"""3-frame 2x2 RGB video tensor"""
|
||||
return torch.rand(3, 2, 2, 3)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_audio():
|
||||
"""Stereo audio with 44.1kHz sample rate"""
|
||||
return AudioInput(
|
||||
{
|
||||
"waveform": torch.rand(1, 2, 1000),
|
||||
"sample_rate": 44100,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_components(sample_images, sample_audio):
|
||||
"""VideoComponents with images, audio, and metadata"""
|
||||
return VideoComponents(
|
||||
images=sample_images,
|
||||
audio=sample_audio,
|
||||
frame_rate=Fraction(30),
|
||||
metadata={"test": "metadata"},
|
||||
)
|
||||
|
||||
|
||||
def create_test_video(width=4, height=4, frames=3, fps=30):
|
||||
"""Helper to create a temporary video file"""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
with av.open(tmp.name, mode="w") as container:
|
||||
stream = container.add_stream("h264", rate=fps)
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
for i in range(frames):
|
||||
frame = av.VideoFrame.from_ndarray(
|
||||
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
|
||||
format="rgb24",
|
||||
)
|
||||
frame = frame.reformat(format="yuv420p")
|
||||
packet = stream.encode(frame)
|
||||
container.mux(packet)
|
||||
|
||||
# Flush
|
||||
packet = stream.encode(None)
|
||||
container.mux(packet)
|
||||
|
||||
return tmp.name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_video_file():
|
||||
"""4x4 video with 3 frames at 30fps"""
|
||||
file_path = create_test_video()
|
||||
yield file_path
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
def test_video_from_components_get_duration(video_components):
|
||||
"""Duration calculated correctly from frame count and frame rate"""
|
||||
video = VideoFromComponents(video_components)
|
||||
duration = video.get_duration()
|
||||
|
||||
expected_duration = 3.0 / 30.0
|
||||
assert duration == pytest.approx(expected_duration)
|
||||
|
||||
|
||||
def test_video_from_components_get_duration_different_frame_rates(sample_images):
|
||||
"""Duration correct for different frame rates including fractional"""
|
||||
# Test with 60 fps
|
||||
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
|
||||
video_60fps = VideoFromComponents(components_60fps)
|
||||
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
|
||||
|
||||
# Test with fractional frame rate (23.976fps)
|
||||
components_frac = VideoComponents(
|
||||
images=sample_images, frame_rate=Fraction(24000, 1001)
|
||||
)
|
||||
video_frac = VideoFromComponents(components_frac)
|
||||
expected_frac = 3.0 / (24000.0 / 1001.0)
|
||||
assert video_frac.get_duration() == pytest.approx(expected_frac)
|
||||
|
||||
|
||||
def test_video_from_components_get_duration_empty_video():
|
||||
"""Duration is zero for empty video"""
|
||||
empty_components = VideoComponents(
|
||||
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
|
||||
)
|
||||
video = VideoFromComponents(empty_components)
|
||||
assert video.get_duration() == 0.0
|
||||
|
||||
|
||||
def test_video_from_components_get_dimensions(video_components):
|
||||
"""Dimensions returned correctly from image tensor shape"""
|
||||
video = VideoFromComponents(video_components)
|
||||
width, height = video.get_dimensions()
|
||||
assert width == 2
|
||||
assert height == 2
|
||||
|
||||
|
||||
def test_video_from_file_get_duration(simple_video_file):
|
||||
"""Duration extracted from file metadata"""
|
||||
video = VideoFromFile(simple_video_file)
|
||||
duration = video.get_duration()
|
||||
assert duration == pytest.approx(0.1, abs=0.01)
|
||||
|
||||
|
||||
def test_video_from_file_get_dimensions(simple_video_file):
|
||||
"""Dimensions read from stream without decoding frames"""
|
||||
video = VideoFromFile(simple_video_file)
|
||||
width, height = video.get_dimensions()
|
||||
assert width == 4
|
||||
assert height == 4
|
||||
|
||||
|
||||
def test_video_from_file_bytesio_input():
|
||||
"""VideoFromFile works with BytesIO input"""
|
||||
buffer = io.BytesIO()
|
||||
with av.open(buffer, mode="w", format="mp4") as container:
|
||||
stream = container.add_stream("h264", rate=30)
|
||||
stream.width = 2
|
||||
stream.height = 2
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
frame = av.VideoFrame.from_ndarray(
|
||||
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
|
||||
)
|
||||
frame = frame.reformat(format="yuv420p")
|
||||
packet = stream.encode(frame)
|
||||
container.mux(packet)
|
||||
packet = stream.encode(None)
|
||||
container.mux(packet)
|
||||
|
||||
buffer.seek(0)
|
||||
video = VideoFromFile(buffer)
|
||||
|
||||
assert video.get_dimensions() == (2, 2)
|
||||
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
|
||||
|
||||
|
||||
def test_video_from_file_invalid_file_error():
|
||||
"""InvalidDataError raised for non-video files"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
|
||||
tmp.write(b"not a video file")
|
||||
tmp.flush()
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
with pytest.raises(InvalidDataError):
|
||||
video = VideoFromFile(tmp_name)
|
||||
video.get_dimensions()
|
||||
finally:
|
||||
os.unlink(tmp_name)
|
||||
|
||||
|
||||
def test_video_from_file_audio_only_error():
|
||||
"""ValueError raised for audio-only files"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
with av.open(tmp_name, mode="w") as container:
|
||||
stream = container.add_stream("aac", rate=44100)
|
||||
stream.sample_rate = 44100
|
||||
stream.format = "fltp"
|
||||
|
||||
audio_data = torch.zeros(1, 1024).numpy()
|
||||
audio_frame = av.AudioFrame.from_ndarray(
|
||||
audio_data, format="fltp", layout="mono"
|
||||
)
|
||||
audio_frame.sample_rate = 44100
|
||||
audio_frame.pts = 0
|
||||
packet = stream.encode(audio_frame)
|
||||
container.mux(packet)
|
||||
|
||||
for packet in stream.encode(None):
|
||||
container.mux(packet)
|
||||
|
||||
with pytest.raises(ValueError, match="No video stream found"):
|
||||
video = VideoFromFile(tmp_name)
|
||||
video.get_dimensions()
|
||||
finally:
|
||||
os.unlink(tmp_name)
|
||||
|
||||
|
||||
def test_single_frame_video():
|
||||
"""Single frame video has correct duration"""
|
||||
components = VideoComponents(
|
||||
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
|
||||
)
|
||||
video = VideoFromComponents(components)
|
||||
assert video.get_duration() == 1.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"frame_rate,expected_fps",
|
||||
[
|
||||
(Fraction(24000, 1001), 24000 / 1001),
|
||||
(Fraction(30000, 1001), 30000 / 1001),
|
||||
(Fraction(25, 1), 25.0),
|
||||
(Fraction(50, 2), 25.0),
|
||||
],
|
||||
)
|
||||
def test_fractional_frame_rates(frame_rate, expected_fps):
|
||||
"""Duration calculated correctly for various fractional frame rates"""
|
||||
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
|
||||
video = VideoFromComponents(components)
|
||||
duration = video.get_duration()
|
||||
expected_duration = 100.0 / expected_fps
|
||||
assert duration == pytest.approx(expected_duration)
|
||||
|
||||
|
||||
def test_duration_consistency(video_components):
|
||||
"""get_duration() consistent with manual calculation from components"""
|
||||
video = VideoFromComponents(video_components)
|
||||
|
||||
duration = video.get_duration()
|
||||
components = video.get_components()
|
||||
manual_duration = float(components.images.shape[0] / components.frame_rate)
|
||||
|
||||
assert duration == pytest.approx(manual_duration)
|
||||