Compare commits: v0.3.29...desktop-re
37 commits
| Author | SHA1 | Date |
|---|---|---|
|  | 188b383c35 |  |
|  | 2c1d686ec6 |  |
|  | e8ddc2be95 |  |
|  | dea1c7474a |  |
|  | 154f2911aa |  |
|  | 3eaad0590e |  |
|  | 7eaff81be1 |  |
|  | 21a11ef817 |  |
|  | 552615235d |  |
|  | 0738e4ea5d |  |
|  | 92cdc692f4 |  |
|  | 2d6805ce57 |  |
|  | a8f63c0d5b |  |
|  | 454a635c1b |  |
|  | 966c43ce26 |  |
|  | 3ab231f01f |  |
|  | 1f3fba2af5 |  |
|  | 5d0d4ee98a |  |
|  | 9d57b8afd8 |  |
|  | 5d51794607 |  |
|  | ce22f687cc |  |
|  | b6fd3ffd10 |  |
|  | 11b72c9c55 |  |
|  | 2c735c13b4 |  |
|  | fd27494441 |  |
|  | f43e1d7f41 |  |
|  | 4486b0d0ff |  |
|  | 636d4bfb89 |  |
|  | dc300a4569 |  |
|  | f3b09b9f2d |  |
|  | 7ecd5e9614 |  |
|  | 2383a39e3b |  |
|  | 34e06bf7ec |  |
|  | 55822faa05 |  |
|  | 880c205df1 |  |
|  | 3dc240d089 |  |
|  | 19373aee75 |  |
.github/workflows/stable-release.yml (vendored, 6 lines changed)
@@ -36,7 +36,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.git_tag }}
-          fetch-depth: 0
+          fetch-depth: 150
           persist-credentials: false
       - uses: actions/cache/restore@v4
         id: cache
@@ -70,7 +70,7 @@ jobs:
            cd ..

            git clone --depth 1 https://github.com/comfyanonymous/taesd
-           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+           cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

            mkdir ComfyUI_windows_portable
            mv python_embeded ComfyUI_windows_portable
@@ -85,7 +85,7 @@ jobs:

            cd ..

-           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
            mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

            cd ComfyUI_windows_portable
.github/workflows/update-api-stubs.yml (vendored, new file, 47 lines)
@@ -0,0 +1,47 @@
+name: Generate Pydantic Stubs from api.comfy.org
+
+on:
+  schedule:
+    - cron: '0 0 * * 1'
+  workflow_dispatch:
+
+jobs:
+  generate-models:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install 'datamodel-code-generator[http]'
+
+      - name: Generate API models
+        run: |
+          datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
+
+      - name: Check for changes
+        id: git-check
+        run: |
+          git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
+
+      - name: Create Pull Request
+        if: steps.git-check.outputs.changes == 'true'
+        uses: peter-evans/create-pull-request@v5
+        with:
+          commit-message: 'chore: update API models from OpenAPI spec'
+          title: 'Update API models from api.comfy.org'
+          body: |
+            This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
+
+            Generated automatically by the a Github workflow.
+          branch: update-api-stubs
+          delete-branch: true
+          base: main
@@ -56,7 +56,7 @@ jobs:
            cd ..

            git clone --depth 1 https://github.com/comfyanonymous/taesd
-           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+           cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

            mkdir ComfyUI_windows_portable_nightly_pytorch
            mv python_embeded ComfyUI_windows_portable_nightly_pytorch
@@ -50,7 +50,7 @@ jobs:

       - uses: actions/checkout@v4
         with:
-          fetch-depth: 0
+          fetch-depth: 150
           persist-credentials: false
       - shell: bash
         run: |
@@ -67,7 +67,7 @@ jobs:
            cd ..

            git clone --depth 1 https://github.com/comfyanonymous/taesd
-           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+           cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

            mkdir ComfyUI_windows_portable
            mv python_embeded ComfyUI_windows_portable
@@ -82,7 +82,7 @@ jobs:

            cd ..

-           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
            mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z

            cd ComfyUI_windows_portable
CODEOWNERS (26 lines changed)
@@ -5,20 +5,20 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
-/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne

 # Python web server
-/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
-/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
-/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
-/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
@@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
+- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - Video Models
    - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
    - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
@@ -215,9 +216,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony

 Nvidia users should install stable pytorch using this command:

-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```

-This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
+This is the command to install pytorch nightly instead which might have performance improvements.

 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
 fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@@ -1,7 +1,7 @@
 """Comfy-specific type hinting"""

 from __future__ import annotations
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, Optional
 from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -115,6 +115,11 @@ class InputTypeOptions(TypedDict):
     """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
     tooltip: NotRequired[str]
     """Tooltip for the input (or widget), shown on pointer hover"""
+    socketless: NotRequired[bool]
+    """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
+    Available from frontend v1.17.5
+    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
+    """
     # class InputTypeNumber(InputTypeOptions):
     #     default: float | int
     min: NotRequired[float]
@@ -224,6 +229,8 @@ class ComfyNodeABC(ABC):
     """Flags a node as experimental, informing users that it may change or not work as expected."""
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
+    API_NODE: Optional[bool]
+    """Flags a node as an API node."""

     @classmethod
     @abstractmethod
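The two additions above extend `InputTypeOptions` and `ComfyNodeABC`. A minimal custom-node sketch using them follows; the node name, category, and field names are illustrative only and not part of the diff, and `socketless` only takes effect with frontend v1.17.5 or newer.

```python
# Illustrative sketch only: exercises the new `socketless` input option and the
# optional API_NODE class attribute introduced in this diff.
class ExampleAPINode:
    API_NODE = True              # new optional attribute: marks the node as an API node
    CATEGORY = "examples"        # hypothetical category
    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "label": ("STRING", {
                    "default": "hello",
                    "tooltip": "Widget-only text input",
                    "socketless": True,   # no link socket is created for this widget (frontend >= 1.17.5)
                }),
            }
        }

    def run(self, label):
        return (label,)
```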
@@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     return control

 def load_controlnet(ckpt_path, model=None, model_options={}):
+    model_options = model_options.copy()
     if "global_average_pooling" not in model_options:
         filename = os.path.splitext(ckpt_path)[0]
         if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
@@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
     def forward(self, pixel_values):
         x = self.patch_embeddings(pixel_values)
         # TODO: mask_token?
-        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
         x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
         return x

@@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module):
         return x


+class VaceWanAttentionBlock(WanAttentionBlock):
+    def __init__(
+        self,
+        cross_attn_type,
+        dim,
+        ffn_dim,
+        num_heads,
+        window_size=(-1, -1),
+        qk_norm=True,
+        cross_attn_norm=False,
+        eps=1e-6,
+        block_id=0,
+        operation_settings={}
+    ):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+        self.block_id = block_id
+        if block_id == 0:
+            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, c, x, **kwargs):
+        if self.block_id == 0:
+            c = self.before_proj(c) + x
+        c = super().forward(c, **kwargs)
+        c_skip = self.after_proj(c)
+        return c_skip, c
+
+
 class Head(nn.Module):

     def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -395,6 +423,7 @@ class WanModel(torch.nn.Module):
         clip_fea=None,
         freqs=None,
         transformer_options={},
+        **kwargs,
     ):
         r"""
         Forward pass through the diffusion model
@@ -471,7 +500,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""
@@ -496,3 +525,115 @@ class WanModel(torch.nn.Module):
         u = torch.einsum('bfhwpqrc->bcfphqwr', u)
         u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
         return u
+
+
+class VaceWanModel(WanModel):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and image-to-video.
+    """
+
+    def __init__(self,
+                 model_type='vace',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 vace_layers=None,
+                 vace_in_dim=None,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        # Vace
+        if vace_layers is not None:
+            self.vace_layers = vace_layers
+            self.vace_in_dim = vace_in_dim
+            # vace blocks
+            self.vace_blocks = nn.ModuleList([
+                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
+                for i in range(self.vace_layers)
+            ])
+
+            self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
+            # vace patch embeddings
+            self.vace_patch_embedding = operations.Conv3d(
+                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
+            )
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        vace_context,
+        vace_strength=1.0,
+        clip_fea=None,
+        freqs=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+        # context
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
+        c = c.flatten(2).transpose(1, 2)
+
+        # arguments
+        x_orig = x
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+            ii = self.vace_layers_mapping.get(i, None)
+            if ii is not None:
+                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                x += c_skip * vace_strength
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
comfy/lora.py (321 lines changed)
@@ -20,6 +20,7 @@ from __future__ import annotations
 import comfy.utils
 import comfy.model_management
 import comfy.model_base
+import comfy.weight_adapter as weight_adapter
 import logging
 import torch

@@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
             dora_scale = lora[dora_scale_name]
             loaded_keys.add(dora_scale_name)

-        reshape_name = "{}.reshape_weight".format(x)
-        reshape = None
-        if reshape_name in lora.keys():
-            try:
-                reshape = lora[reshape_name].tolist()
-                loaded_keys.add(reshape_name)
-            except:
-                pass
-
-        regular_lora = "{}.lora_up.weight".format(x)
-        diffusers_lora = "{}_lora.up.weight".format(x)
-        diffusers2_lora = "{}.lora_B.weight".format(x)
-        diffusers3_lora = "{}.lora.up.weight".format(x)
-        mochi_lora = "{}.lora_B".format(x)
-        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
-        A_name = None
-
-        if regular_lora in lora.keys():
-            A_name = regular_lora
-            B_name = "{}.lora_down.weight".format(x)
-            mid_name = "{}.lora_mid.weight".format(x)
-        elif diffusers_lora in lora.keys():
-            A_name = diffusers_lora
-            B_name = "{}_lora.down.weight".format(x)
-            mid_name = None
-        elif diffusers2_lora in lora.keys():
-            A_name = diffusers2_lora
-            B_name = "{}.lora_A.weight".format(x)
-            mid_name = None
-        elif diffusers3_lora in lora.keys():
-            A_name = diffusers3_lora
-            B_name = "{}.lora.down.weight".format(x)
-            mid_name = None
-        elif mochi_lora in lora.keys():
-            A_name = mochi_lora
-            B_name = "{}.lora_A".format(x)
-            mid_name = None
-        elif transformers_lora in lora.keys():
-            A_name = transformers_lora
-            B_name ="{}.lora_linear_layer.down.weight".format(x)
-            mid_name = None
-
-        if A_name is not None:
-            mid = None
-            if mid_name is not None and mid_name in lora.keys():
-                mid = lora[mid_name]
-                loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
-            loaded_keys.add(A_name)
-            loaded_keys.add(B_name)
-
-
-        ######## loha
-        hada_w1_a_name = "{}.hada_w1_a".format(x)
-        hada_w1_b_name = "{}.hada_w1_b".format(x)
-        hada_w2_a_name = "{}.hada_w2_a".format(x)
-        hada_w2_b_name = "{}.hada_w2_b".format(x)
-        hada_t1_name = "{}.hada_t1".format(x)
-        hada_t2_name = "{}.hada_t2".format(x)
-        if hada_w1_a_name in lora.keys():
-            hada_t1 = None
-            hada_t2 = None
-            if hada_t1_name in lora.keys():
-                hada_t1 = lora[hada_t1_name]
-                hada_t2 = lora[hada_t2_name]
-                loaded_keys.add(hada_t1_name)
-                loaded_keys.add(hada_t2_name)
-
-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
-            loaded_keys.add(hada_w1_a_name)
-            loaded_keys.add(hada_w1_b_name)
-            loaded_keys.add(hada_w2_a_name)
-            loaded_keys.add(hada_w2_b_name)
-
-
-        ######## lokr
-        lokr_w1_name = "{}.lokr_w1".format(x)
-        lokr_w2_name = "{}.lokr_w2".format(x)
-        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
-        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
-        lokr_t2_name = "{}.lokr_t2".format(x)
-        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
-        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
-
-        lokr_w1 = None
-        if lokr_w1_name in lora.keys():
-            lokr_w1 = lora[lokr_w1_name]
-            loaded_keys.add(lokr_w1_name)
-
-        lokr_w2 = None
-        if lokr_w2_name in lora.keys():
-            lokr_w2 = lora[lokr_w2_name]
-            loaded_keys.add(lokr_w2_name)
-
-        lokr_w1_a = None
-        if lokr_w1_a_name in lora.keys():
-            lokr_w1_a = lora[lokr_w1_a_name]
-            loaded_keys.add(lokr_w1_a_name)
-
-        lokr_w1_b = None
-        if lokr_w1_b_name in lora.keys():
-            lokr_w1_b = lora[lokr_w1_b_name]
-            loaded_keys.add(lokr_w1_b_name)
-
-        lokr_w2_a = None
-        if lokr_w2_a_name in lora.keys():
-            lokr_w2_a = lora[lokr_w2_a_name]
-            loaded_keys.add(lokr_w2_a_name)
-
-        lokr_w2_b = None
-        if lokr_w2_b_name in lora.keys():
-            lokr_w2_b = lora[lokr_w2_b_name]
-            loaded_keys.add(lokr_w2_b_name)
-
-        lokr_t2 = None
-        if lokr_t2_name in lora.keys():
-            lokr_t2 = lora[lokr_t2_name]
-            loaded_keys.add(lokr_t2_name)
-
-        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
-
-        #glora
-        a1_name = "{}.a1.weight".format(x)
-        a2_name = "{}.a2.weight".format(x)
-        b1_name = "{}.b1.weight".format(x)
-        b2_name = "{}.b2.weight".format(x)
-        if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
-            loaded_keys.add(a1_name)
-            loaded_keys.add(a2_name)
-            loaded_keys.add(b1_name)
-            loaded_keys.add(b2_name)
+        for adapter_cls in weight_adapter.adapters:
+            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
+            if adapter is not None:
+                patch_dict[to_load[x]] = adapter
+                loaded_keys.update(adapter.loaded_keys)
+                continue

         w_norm_name = "{}.w_norm".format(x)
         b_norm_name = "{}.b_norm".format(x)
@@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}):
     return key_map


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
-    lora_diff *= alpha
-    weight_calc = weight + function(lora_diff).type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
-
-    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
-    if strength != 1.0:
-        weight_calc -= weight
-        weight += strength * (weight_calc)
-    else:
-        weight[:] = weight_calc
-    return weight
-
 def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     """
     Pad a tensor to a new shape with zeros.
@@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
         if isinstance(v, list):
             v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

+        if isinstance(v, weight_adapter.WeightAdapterBase):
+            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
+            if output is None:
+                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
+            else:
+                weight = output
+                if old_weight is not None:
+                    weight = old_weight
+            continue
+
         if len(v) == 1:
             patch_type = "diff"
         elif len(v) == 2:
@@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
             diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                           comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
             weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
-        elif patch_type == "lora": #lora/locon
-            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
-            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
-            dora_scale = v[4]
-            reshape = v[5]
-
-            if reshape is not None:
-                weight = pad_tensor_to_shape(weight, reshape)
-
-            if v[2] is not None:
-                alpha = v[2] / mat2.shape[0]
-            else:
-                alpha = 1.0
-
-            if v[3] is not None:
-                #locon mid weights, hopefully the math is fine because I didn't properly test it
-                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
-                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-            try:
-                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "lokr":
-            w1 = v[0]
-            w2 = v[1]
-            w1_a = v[3]
-            w1_b = v[4]
-            w2_a = v[5]
-            w2_b = v[6]
-            t2 = v[7]
-            dora_scale = v[8]
-            dim = None
-
-            if w1 is None:
-                dim = w1_b.shape[0]
-                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
-            else:
-                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
-            if w2 is None:
-                dim = w2_b.shape[0]
-                if t2 is None:
-                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
-                else:
-                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
-            else:
-                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
-            if len(w2.shape) == 4:
-                w1 = w1.unsqueeze(2).unsqueeze(2)
-            if v[2] is not None and dim is not None:
-                alpha = v[2] / dim
-            else:
-                alpha = 1.0
-
-            try:
-                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "loha":
-            w1a = v[0]
-            w1b = v[1]
-            if v[2] is not None:
-                alpha = v[2] / w1b.shape[0]
-            else:
-                alpha = 1.0
-
-            w2a = v[3]
-            w2b = v[4]
-            dora_scale = v[7]
-            if v[5] is not None: #cp decomposition
-                t1 = v[5]
-                t2 = v[6]
-                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
-                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
-            else:
-                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
-            try:
-                lora_diff = (m1 * m2).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "glora":
-            dora_scale = v[5]
-
-            old_glora = False
-            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
-                rank = v[0].shape[0]
-                old_glora = True
-
-            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
-                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
-                    pass
-                else:
-                    old_glora = False
-                    rank = v[1].shape[0]
-
-            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
-            if v[4] is not None:
-                alpha = v[4] / rank
-            else:
-                alpha = 1.0
-
-            try:
-                if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
-                else:
-                    if weight.dim() > 2:
-                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
-                    else:
-                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
-                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
-
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
         else:
             logging.warning("patch type not recognized {} {}".format(patch_type, key))

@@ -1043,6 +1043,37 @@ class WAN21(BaseModel):
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out

+
+class WAN21_Vace(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        noise = kwargs.get("noise", None)
+        noise_shape = list(noise.shape)
+        vace_frames = kwargs.get("vace_frames", None)
+        if vace_frames is None:
+            noise_shape[1] = 32
+            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+
+        for i in range(0, vace_frames.shape[1], 16):
+            vace_frames = vace_frames.clone()
+            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
+
+        mask = kwargs.get("vace_mask", None)
+        if mask is None:
+            noise_shape[1] = 64
+            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
+
+        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
+
+        vace_strength = kwargs.get("vace_strength", 1.0)
+        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
+        return out
+
+
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -317,6 +317,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["cross_attn_norm"] = True
         dit_config["eps"] = 1e-6
         dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+        if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "vace"
+            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
+        else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"
             else:
@@ -725,6 +725,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+    if args.fp8_e8m0fnu_unet:
+        return torch.float8_e8m0fnu

     fp8_dtype = None
     if weight_dtype in FLOAT8_TYPES:
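Together with the new `--fp8_e8m0fnu-unet` flag added in `comfy/cli_args.py` above, `unet_dtype()` now recognizes three fp8 storage formats. A standalone sketch of that flag-to-dtype mapping follows; it assumes a PyTorch build that exposes `torch.float8_e8m0fnu` (the diff itself relies on it) and is not the actual ComfyUI implementation.

```python
# Sketch only: mirrors the argument names from comfy/cli_args.py, but is a
# simplified stand-in for the real unet_dtype() selection logic.
import torch

def pick_fp8_storage_dtype(args):
    if getattr(args, "fp8_e4m3fn_unet", False):
        return torch.float8_e4m3fn
    if getattr(args, "fp8_e5m2_unet", False):
        return torch.float8_e5m2
    if getattr(args, "fp8_e8m0fnu_unet", False):
        return torch.float8_e8m0fnu  # requires a PyTorch version that defines this dtype
    return None  # fall back to the usual fp16/bf16 selection
```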
comfy/sd.py (34 lines changed)
@@ -703,6 +703,7 @@ class CLIPType(Enum):
     COSMOS = 11
     LUMINA2 = 12
     WAN = 13
+    HIDREAM = 14


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -791,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.SD3:
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+        elif clip_type == CLIPType.HIDREAM:
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLRefinerClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -811,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+        elif clip_type == CLIPType.HIDREAM:
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
+                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -827,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+        elif te_model == TEModel.LLAMA3_8:
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
+                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
+            # clip_l
             if clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
             else:
                 clip_target.clip = sd1_clip.SD1ClipModel
                 clip_target.tokenizer = sd1_clip.SD1Tokenizer
@@ -848,6 +864,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.HUNYUAN_VIDEO:
             clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
+        elif clip_type == CLIPType.HIDREAM:
+            # Detect
+            hidream_dualclip_classes = []
+            for hidream_te in clip_data:
+                te_model = detect_te_model(hidream_te)
+                hidream_dualclip_classes.append(te_model)
+
+            clip_l = TEModel.CLIP_L in hidream_dualclip_classes
+            clip_g = TEModel.CLIP_G in hidream_dualclip_classes
+            t5 = TEModel.T5_XXL in hidream_dualclip_classes
+            llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
+
+            # Initialize t5xxl_detect and llama_detect kwargs if needed
+            t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
+            llama_kwargs = llama_detect(clip_data) if llama else {}
+
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
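A hedged usage sketch of the new `CLIPType.HIDREAM` path through `load_text_encoder_state_dicts()`; the checkpoint file names are placeholders, and `comfy` must be importable from a ComfyUI checkout.

```python
# Placeholder file names; the HIDREAM branch above detects which encoders
# (clip_l / clip_g / t5 / llama) are present in whatever state dicts you pass.
import comfy.sd
import comfy.utils

paths = ["clip_l.safetensors", "llama_3.1_8b_instruct.safetensors"]  # hypothetical
state_dicts = [comfy.utils.load_torch_file(p, safe_load=True) for p in paths]

clip = comfy.sd.load_text_encoder_state_dicts(
    state_dicts=state_dicts,
    embedding_directory="models/embeddings",
    clip_type=comfy.sd.CLIPType.HIDREAM,
)
```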
@@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V):
         out = model_base.WAN21(self, image_to_video=False, device=device)
         return out

+class WAN21_Vace(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "vace",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
+        return out
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE):
         return None # TODO


-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]

 models += [SVD_img2vid]
@@ -109,14 +109,18 @@ class HiDreamTEModel(torch.nn.Module):
         if self.t5xxl is not None:
             t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
             t5_out, t5_pooled = t5_output[:2]
+        else:
+            t5_out = None

         if self.llama is not None:
             ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
             ll_out, ll_pooled = ll_output[:2]
             ll_out = ll_out[:, 1:]
+        else:
+            ll_out = None

         if t5_out is None:
-            t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
+            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

         if ll_out is None:
             ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
@@ -1,4 +1,5 @@
 import torch
+import os

 class SPieceTokenizer:
     @staticmethod
@@ -15,6 +16,8 @@ class SPieceTokenizer:
         if isinstance(tokenizer_path, bytes):
             self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
         else:
+            if not os.path.isfile(tokenizer_path):
+                raise ValueError("invalid tokenizer")
             self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)

     def get_vocab(self):
comfy/weight_adapter/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
|
|||||||
|
from .base import WeightAdapterBase
|
||||||
|
from .lora import LoRAAdapter
|
||||||
|
from .loha import LoHaAdapter
|
||||||
|
from .lokr import LoKrAdapter
|
||||||
|
from .glora import GLoRAAdapter
|
||||||
|
from .oft import OFTAdapter
|
||||||
|
from .boft import BOFTAdapter
|
||||||
|
|
||||||
|
|
||||||
|
adapters: list[type[WeightAdapterBase]] = [
|
||||||
|
LoRAAdapter,
|
||||||
|
LoHaAdapter,
|
||||||
|
LoKrAdapter,
|
||||||
|
GLoRAAdapter,
|
||||||
|
OFTAdapter,
|
||||||
|
BOFTAdapter,
|
||||||
|
]
|
||||||
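The adapters list acts as a small plugin registry: each class below exposes a load() classmethod that returns an instance when its key pattern is present in the LoRA state dict and None otherwise. A hedged sketch of how such a registry could be probed; the calling convention is inferred from the load() signatures in the files below, not quoted from ComfyUI:

import torch
from comfy.weight_adapter import adapters

def detect_adapter(key: str, lora_sd: dict[str, torch.Tensor], alpha=None, dora_scale=None):
    # Try each known adapter format until one recognizes the keys for this layer.
    for adapter_cls in adapters:
        adapter = adapter_cls.load(key, lora_sd, alpha, dora_scale)
        if adapter is not None:
            return adapter
    return None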
104  comfy/weight_adapter/base.py  Normal file
@@ -0,0 +1,104 @@
from typing import Optional

import torch
import torch.nn as nn

import comfy.model_management


class WeightAdapterBase:
    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    @classmethod
    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        raise NotImplementedError


class WeightAdapterTrainBase(nn.Module):
    def __init__(self):
        super().__init__()

    # [TODO] Collaborate with LoRA training PR #7032


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight


def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        the original tensor will be truncated in that dimension.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor
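weight_decompose() above applies a DoRA-style rescale: the merged weight (frozen weight plus adapter delta) is divided by a per-channel norm and multiplied by the learned dora_scale magnitude, so the adapter steers direction while dora_scale controls magnitude. pad_tensor_to_shape() supports adapters whose reshape metadata asks for a larger base weight. A small usage sketch of the latter, assuming the comfy package is importable:

import torch
from comfy.weight_adapter.base import pad_tensor_to_shape

w = torch.ones(2, 3)
padded = pad_tensor_to_shape(w, [4, 5])
print(padded.shape)          # torch.Size([4, 5])
print(padded[:2, :3].sum())  # tensor(6.) -- original values preserved, remainder zero-filled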
115
comfy/weight_adapter/boft.py
Normal file
115
comfy/weight_adapter/boft.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class BOFTAdapter(WeightAdapterBase):
|
||||||
|
name = "boft"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["BOFTAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
blocks_name = "{}.boft_blocks".format(x)
|
||||||
|
rescale_name = "{}.rescale".format(x)
|
||||||
|
|
||||||
|
blocks = None
|
||||||
|
if blocks_name in lora.keys():
|
||||||
|
blocks = lora[blocks_name]
|
||||||
|
if blocks.ndim == 4:
|
||||||
|
loaded_keys.add(blocks_name)
|
||||||
|
|
||||||
|
rescale = None
|
||||||
|
if rescale_name in lora.keys():
|
||||||
|
rescale = lora[rescale_name]
|
||||||
|
loaded_keys.add(rescale_name)
|
||||||
|
|
||||||
|
if blocks is not None:
|
||||||
|
weights = (blocks, rescale, alpha, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0]
|
||||||
|
rescale = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
dora_scale = v[3]
|
||||||
|
|
||||||
|
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
boft_m, block_num, boft_b, *_ = blocks.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get r
|
||||||
|
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
|
||||||
|
# for Q = -Q^T
|
||||||
|
q = blocks - blocks.transpose(1, 2)
|
||||||
|
normed_q = q
|
||||||
|
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
# use float() to prevent unsupported type in .inverse()
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
r = r.to(original_weight)
|
||||||
|
|
||||||
|
inp = org = original_weight
|
||||||
|
|
||||||
|
r_b = boft_b//2
|
||||||
|
for i in range(boft_m):
|
||||||
|
bi = r[i]
|
||||||
|
g = 2
|
||||||
|
k = 2**i * r_b
|
||||||
|
if strength != 1:
|
||||||
|
bi = bi * strength + (1-strength) * I
|
||||||
|
inp = (
|
||||||
|
inp.unflatten(-1, (-1, g, k))
|
||||||
|
.transpose(-2, -1)
|
||||||
|
.flatten(-3)
|
||||||
|
.unflatten(-1, (-1, boft_b))
|
||||||
|
)
|
||||||
|
inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
|
||||||
|
inp = (
|
||||||
|
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rescale is not None:
|
||||||
|
inp = inp * rescale
|
||||||
|
|
||||||
|
lora_diff = inp - org
|
||||||
|
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
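BOFT above (and OFT below) turn a learned block into an orthogonal rotation via the Cayley transform: with Q = blocks - blocks^T skew-symmetric, R = (I + Q)(I - Q)^-1 is orthogonal, so applying R to the frozen weight rotates it without rescaling it. A self-contained check of that property (illustrative, single block, double precision for a tight tolerance):

import torch

b = 8
blocks = torch.randn(b, b, dtype=torch.float64)
q = blocks - blocks.T                    # skew-symmetric: q == -q.T
i = torch.eye(b, dtype=torch.float64)
r = (i + q) @ torch.linalg.inv(i - q)    # Cayley transform

print(torch.allclose(r @ r.T, i, atol=1e-8))  # True: r is orthogonal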
93
comfy/weight_adapter/glora.py
Normal file
93
comfy/weight_adapter/glora.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class GLoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "glora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["GLoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
a1_name = "{}.a1.weight".format(x)
|
||||||
|
a2_name = "{}.a2.weight".format(x)
|
||||||
|
b1_name = "{}.b1.weight".format(x)
|
||||||
|
b2_name = "{}.b2.weight".format(x)
|
||||||
|
if a1_name in lora:
|
||||||
|
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||||
|
loaded_keys.add(a1_name)
|
||||||
|
loaded_keys.add(a2_name)
|
||||||
|
loaded_keys.add(b1_name)
|
||||||
|
loaded_keys.add(b2_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
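In the "old lycoris glora" branch above, the delta for a linear weight W is B2·B1 + W·A2·A1: a plain low-rank additive term plus a low-rank reparameterization of the frozen weight itself. A shape-only sketch with random values; the names mirror the a1/a2/b1/b2 tensors loaded above:

import torch

out_f, in_f, rank = 16, 12, 4
weight = torch.randn(out_f, in_f)

a1 = torch.randn(rank, in_f)    # {x}.a1.weight
a2 = torch.randn(in_f, rank)    # {x}.a2.weight
b1 = torch.randn(rank, in_f)    # {x}.b1.weight
b2 = torch.randn(out_f, rank)   # {x}.b2.weight

# old-format GLoRA delta: additive low-rank term plus a weight-dependent term
lora_diff = torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)
print(lora_diff.shape)  # torch.Size([16, 12])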
100
comfy/weight_adapter/loha.py
Normal file
100
comfy/weight_adapter/loha.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoHaAdapter(WeightAdapterBase):
|
||||||
|
name = "loha"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoHaAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||||
|
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||||
|
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||||
|
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||||
|
hada_t1_name = "{}.hada_t1".format(x)
|
||||||
|
hada_t2_name = "{}.hada_t2".format(x)
|
||||||
|
if hada_w1_a_name in lora.keys():
|
||||||
|
hada_t1 = None
|
||||||
|
hada_t2 = None
|
||||||
|
if hada_t1_name in lora.keys():
|
||||||
|
hada_t1 = lora[hada_t1_name]
|
||||||
|
hada_t2 = lora[hada_t2_name]
|
||||||
|
loaded_keys.add(hada_t1_name)
|
||||||
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
|
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||||
|
loaded_keys.add(hada_w1_a_name)
|
||||||
|
loaded_keys.add(hada_w1_b_name)
|
||||||
|
loaded_keys.add(hada_w2_a_name)
|
||||||
|
loaded_keys.add(hada_w2_b_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
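LoHa builds the delta as a Hadamard (element-wise) product of two low-rank products, delta_W = (w1a·w1b) * (w2a·w2b), which can express higher-rank updates than either factor alone. A tiny sketch of the non-CP branch above:

import torch

out_f, in_f, rank = 8, 6, 2
w1a, w1b = torch.randn(out_f, rank), torch.randn(rank, in_f)
w2a, w2b = torch.randn(out_f, rank), torch.randn(rank, in_f)

# Element-wise product of two low-rank products, as in LoHaAdapter.calculate_weight
lora_diff = (w1a @ w1b) * (w2a @ w2b)
print(lora_diff.shape)  # torch.Size([8, 6])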
133
comfy/weight_adapter/lokr.py
Normal file
133
comfy/weight_adapter/lokr.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class LoKrAdapter(WeightAdapterBase):
|
||||||
|
name = "lokr"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoKrAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||||
|
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||||
|
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||||
|
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||||
|
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||||
|
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||||
|
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||||
|
|
||||||
|
lokr_w1 = None
|
||||||
|
if lokr_w1_name in lora.keys():
|
||||||
|
lokr_w1 = lora[lokr_w1_name]
|
||||||
|
loaded_keys.add(lokr_w1_name)
|
||||||
|
|
||||||
|
lokr_w2 = None
|
||||||
|
if lokr_w2_name in lora.keys():
|
||||||
|
lokr_w2 = lora[lokr_w2_name]
|
||||||
|
loaded_keys.add(lokr_w2_name)
|
||||||
|
|
||||||
|
lokr_w1_a = None
|
||||||
|
if lokr_w1_a_name in lora.keys():
|
||||||
|
lokr_w1_a = lora[lokr_w1_a_name]
|
||||||
|
loaded_keys.add(lokr_w1_a_name)
|
||||||
|
|
||||||
|
lokr_w1_b = None
|
||||||
|
if lokr_w1_b_name in lora.keys():
|
||||||
|
lokr_w1_b = lora[lokr_w1_b_name]
|
||||||
|
loaded_keys.add(lokr_w1_b_name)
|
||||||
|
|
||||||
|
lokr_w2_a = None
|
||||||
|
if lokr_w2_a_name in lora.keys():
|
||||||
|
lokr_w2_a = lora[lokr_w2_a_name]
|
||||||
|
loaded_keys.add(lokr_w2_a_name)
|
||||||
|
|
||||||
|
lokr_w2_b = None
|
||||||
|
if lokr_w2_b_name in lora.keys():
|
||||||
|
lokr_w2_b = lora[lokr_w2_b_name]
|
||||||
|
loaded_keys.add(lokr_w2_b_name)
|
||||||
|
|
||||||
|
lokr_t2 = None
|
||||||
|
if lokr_t2_name in lora.keys():
|
||||||
|
lokr_t2 = lora[lokr_t2_name]
|
||||||
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
|
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||||
|
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
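LoKr factors the delta as a Kronecker product, delta_W = kron(w1, w2): row counts multiply and column counts multiply, so two small matrices expand to the full weight shape. A quick shape check matching the torch.kron call above:

import torch

w1 = torch.randn(4, 3)
w2 = torch.randn(8, 16)

lora_diff = torch.kron(w1, w2)   # shape (4*8, 3*16)
print(lora_diff.shape)           # torch.Size([32, 48])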
142
comfy/weight_adapter/lora.py
Normal file
142
comfy/weight_adapter/lora.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAAdapter(WeightAdapterBase):
|
||||||
|
name = "lora"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["LoRAAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
|
||||||
|
reshape_name = "{}.reshape_weight".format(x)
|
||||||
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||||
|
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||||
|
mochi_lora = "{}.lora_B".format(x)
|
||||||
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
|
A_name = None
|
||||||
|
|
||||||
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers2_lora in lora.keys():
|
||||||
|
A_name = diffusers2_lora
|
||||||
|
B_name = "{}.lora_A.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif diffusers3_lora in lora.keys():
|
||||||
|
A_name = diffusers3_lora
|
||||||
|
B_name = "{}.lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif mochi_lora in lora.keys():
|
||||||
|
A_name = mochi_lora
|
||||||
|
B_name = "{}.lora_A".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif transformers_lora in lora.keys():
|
||||||
|
A_name = transformers_lora
|
||||||
|
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
|
mid = None
|
||||||
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
|
mid = lora[mid_name]
|
||||||
|
loaded_keys.add(mid_name)
|
||||||
|
reshape = None
|
||||||
|
if reshape_name in lora.keys():
|
||||||
|
try:
|
||||||
|
reshape = lora[reshape_name].tolist()
|
||||||
|
loaded_keys.add(reshape_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
|
||||||
|
loaded_keys.add(A_name)
|
||||||
|
loaded_keys.add(B_name)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
mat1 = comfy.model_management.cast_to_device(
|
||||||
|
v[0], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(
|
||||||
|
v[1], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
dora_scale = v[4]
|
||||||
|
reshape = v[5]
|
||||||
|
|
||||||
|
if reshape is not None:
|
||||||
|
weight = pad_tensor_to_shape(weight, reshape)
|
||||||
|
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(
|
||||||
|
v[3], weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = (
|
||||||
|
torch.mm(
|
||||||
|
mat2.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
mat3.transpose(0, 1).flatten(start_dim=1),
|
||||||
|
)
|
||||||
|
.reshape(final_shape)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(
|
||||||
|
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||||
|
).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
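Plain LoRA reduces to one computation, delta_W = up·down scaled by alpha/rank (rank being the inner dimension of the pair); most of the adapter above just maps the many key-naming schemes (lora_up/lora_down, lora_A/lora_B, the diffusers and transformers variants) onto it. A minimal merge sketch with made-up sizes:

import torch

out_f, in_f, rank, alpha, strength = 32, 16, 4, 8.0, 1.0
weight = torch.randn(out_f, in_f)
up = torch.randn(out_f, rank)    # e.g. {x}.lora_up.weight
down = torch.randn(rank, in_f)   # e.g. {x}.lora_down.weight

scale = alpha / rank             # alpha is divided by the down matrix's rank, as above
merged = weight + (strength * scale) * torch.mm(up, down)
print(merged.shape)  # torch.Size([32, 16])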
94
comfy/weight_adapter/oft.py
Normal file
94
comfy/weight_adapter/oft.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
class OFTAdapter(WeightAdapterBase):
|
||||||
|
name = "oft"
|
||||||
|
|
||||||
|
def __init__(self, loaded_keys, weights):
|
||||||
|
self.loaded_keys = loaded_keys
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
loaded_keys: set[str] = None,
|
||||||
|
) -> Optional["OFTAdapter"]:
|
||||||
|
if loaded_keys is None:
|
||||||
|
loaded_keys = set()
|
||||||
|
blocks_name = "{}.oft_blocks".format(x)
|
||||||
|
rescale_name = "{}.rescale".format(x)
|
||||||
|
|
||||||
|
blocks = None
|
||||||
|
if blocks_name in lora.keys():
|
||||||
|
blocks = lora[blocks_name]
|
||||||
|
if blocks.ndim == 3:
|
||||||
|
loaded_keys.add(blocks_name)
|
||||||
|
|
||||||
|
rescale = None
|
||||||
|
if rescale_name in lora.keys():
|
||||||
|
rescale = lora[rescale_name]
|
||||||
|
loaded_keys.add(rescale_name)
|
||||||
|
|
||||||
|
if blocks is not None:
|
||||||
|
weights = (blocks, rescale, alpha, dora_scale)
|
||||||
|
return cls(loaded_keys, weights)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_weight(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
key,
|
||||||
|
strength,
|
||||||
|
strength_model,
|
||||||
|
offset,
|
||||||
|
function,
|
||||||
|
intermediate_dtype=torch.float32,
|
||||||
|
original_weight=None,
|
||||||
|
):
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0]
|
||||||
|
rescale = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
dora_scale = v[3]
|
||||||
|
|
||||||
|
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
block_num, block_size, *_ = blocks.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get r
|
||||||
|
I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
|
||||||
|
# for Q = -Q^T
|
||||||
|
q = blocks - blocks.transpose(1, 2)
|
||||||
|
normed_q = q
|
||||||
|
if alpha > 0: # alpha in oft/boft is for constraint
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
# use float() to prevent unsupported type in .inverse()
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
r = r.to(original_weight)
|
||||||
|
lora_diff = torch.einsum(
|
||||||
|
"k n m, k n ... -> k m ...",
|
||||||
|
(r * strength) - strength * I,
|
||||||
|
original_weight,
|
||||||
|
)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
|
return weight
|
||||||
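OFT uses the same Cayley rotation as BOFT but without the butterfly factorization: the per-block rotation r is applied to original_weight through the einsum above, and the delta is built from (r*strength - strength*I) so that strength 0 leaves the weight untouched. A small illustrative check that identity rotations produce a zero delta:

import torch

block_num, block_size, in_f = 2, 4, 6
r = torch.stack([torch.eye(block_size) for _ in range(block_num)])  # identity rotations
w = torch.randn(block_num, block_size, in_f)                        # weight viewed per block
i = torch.eye(block_size)

diff = torch.einsum("k n m, k n ... -> k m ...", r - i, w)
print(diff.abs().max())  # tensor(0.)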
0
comfy_api_nodes/__init__.py
Normal file
0
comfy_api_nodes/__init__.py
Normal file
17
comfy_api_nodes/apis/PixverseController.py
Normal file
17
comfy_api_nodes/apis/PixverseController.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: https://api.comfy.org/openapi
|
||||||
|
# timestamp: 2025-04-23T15:56:33+00:00
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from . import PixverseDto
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseData(BaseModel):
|
||||||
|
ErrCode: Optional[int] = None
|
||||||
|
ErrMsg: Optional[str] = None
|
||||||
|
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
||||||
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: https://api.comfy.org/openapi
|
||||||
|
# timestamp: 2025-04-23T15:56:33+00:00
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, constr
|
||||||
|
|
||||||
|
|
||||||
|
class V2OpenAPII2VResp(BaseModel):
|
||||||
|
video_id: Optional[int] = Field(None, description='Video_id')
|
||||||
|
|
||||||
|
|
||||||
|
class V2OpenAPIT2VReq(BaseModel):
|
||||||
|
aspect_ratio: str = Field(
|
||||||
|
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
||||||
|
)
|
||||||
|
duration: int = Field(
|
||||||
|
...,
|
||||||
|
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
||||||
|
examples=[5],
|
||||||
|
)
|
||||||
|
model: str = Field(
|
||||||
|
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
||||||
|
)
|
||||||
|
motion_mode: Optional[str] = Field(
|
||||||
|
'normal',
|
||||||
|
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
||||||
|
examples=['normal'],
|
||||||
|
)
|
||||||
|
negative_prompt: Optional[constr(max_length=2048)] = Field(
|
||||||
|
None, description='Negative prompt\n'
|
||||||
|
)
|
||||||
|
prompt: constr(max_length=2048) = Field(..., description='Prompt')
|
||||||
|
quality: str = Field(
|
||||||
|
...,
|
||||||
|
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
||||||
|
examples=['540p'],
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
||||||
|
style: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
||||||
|
examples=['anime'],
|
||||||
|
)
|
||||||
|
template_id: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description='Template ID (template_id must be activated before use)',
|
||||||
|
examples=[302325299692608],
|
||||||
|
)
|
||||||
|
water_mark: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description='Watermark (true: add watermark, false: no watermark)',
|
||||||
|
examples=[False],
|
||||||
|
)
|
||||||
422
comfy_api_nodes/apis/__init__.py
Normal file
422
comfy_api_nodes/apis/__init__.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: https://api.comfy.org/openapi
|
||||||
|
# timestamp: 2025-04-23T15:56:33+00:00
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import AnyUrl, BaseModel, Field, confloat, conint
|
||||||
|
|
||||||
|
class Customer(BaseModel):
|
||||||
|
createdAt: Optional[datetime] = Field(
|
||||||
|
None, description='The date and time the user was created'
|
||||||
|
)
|
||||||
|
email: Optional[str] = Field(None, description='The email address for this user')
|
||||||
|
id: str = Field(..., description='The firebase UID of the user')
|
||||||
|
name: Optional[str] = Field(None, description='The name for this user')
|
||||||
|
updatedAt: Optional[datetime] = Field(
|
||||||
|
None, description='The date and time the user was last updated'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Error(BaseModel):
|
||||||
|
details: Optional[List[str]] = Field(
|
||||||
|
None,
|
||||||
|
description='Optional detailed information about the error or hints for resolving it.',
|
||||||
|
)
|
||||||
|
message: Optional[str] = Field(
|
||||||
|
None, description='A clear and concise description of the error.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
error: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
class ImageRequest(BaseModel):
|
||||||
|
aspect_ratio: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
|
||||||
|
)
|
||||||
|
color_palette: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
|
||||||
|
)
|
||||||
|
magic_prompt_option: Optional[str] = Field(
|
||||||
|
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
|
||||||
|
)
|
||||||
|
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
|
||||||
|
negative_prompt: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
|
||||||
|
)
|
||||||
|
num_images: Optional[conint(ge=1, le=8)] = Field(
|
||||||
|
1, description='Optional. Number of images to generate (1-8). Defaults to 1.'
|
||||||
|
)
|
||||||
|
prompt: str = Field(
|
||||||
|
..., description='Required. The prompt to use to generate the image.'
|
||||||
|
)
|
||||||
|
resolution: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
|
||||||
|
)
|
||||||
|
seed: Optional[conint(ge=0, le=2147483647)] = Field(
|
||||||
|
None, description='Optional. A number between 0 and 2147483647.'
|
||||||
|
)
|
||||||
|
style_type: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Datum(BaseModel):
|
||||||
|
is_image_safe: Optional[bool] = Field(
|
||||||
|
None, description='Indicates whether the image is considered safe.'
|
||||||
|
)
|
||||||
|
prompt: Optional[str] = Field(
|
||||||
|
None, description='The prompt used to generate this image.'
|
||||||
|
)
|
||||||
|
resolution: Optional[str] = Field(
|
||||||
|
None, description="The resolution of the generated image (e.g., '1024x1024')."
|
||||||
|
)
|
||||||
|
seed: Optional[int] = Field(
|
||||||
|
None, description='The seed value used for this generation.'
|
||||||
|
)
|
||||||
|
style_type: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
|
||||||
|
)
|
||||||
|
url: Optional[str] = Field(None, description='URL to the generated image.')
|
||||||
|
|
||||||
|
|
||||||
|
class Code(Enum):
|
||||||
|
int_1100 = 1100
|
||||||
|
int_1101 = 1101
|
||||||
|
int_1102 = 1102
|
||||||
|
int_1103 = 1103
|
||||||
|
|
||||||
|
|
||||||
|
class Code1(Enum):
|
||||||
|
int_1000 = 1000
|
||||||
|
int_1001 = 1001
|
||||||
|
int_1002 = 1002
|
||||||
|
int_1003 = 1003
|
||||||
|
int_1004 = 1004
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio(str, Enum):
|
||||||
|
field_16_9 = '16:9'
|
||||||
|
field_9_16 = '9:16'
|
||||||
|
field_1_1 = '1:1'
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
pan: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
roll: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
tilt: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
vertical: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
zoom: Optional[confloat(ge=-10.0, le=10.0)] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Type(str, Enum):
|
||||||
|
simple = 'simple'
|
||||||
|
down_back = 'down_back'
|
||||||
|
forward_up = 'forward_up'
|
||||||
|
right_turn_forward = 'right_turn_forward'
|
||||||
|
left_turn_forward = 'left_turn_forward'
|
||||||
|
|
||||||
|
|
||||||
|
class CameraControl(BaseModel):
|
||||||
|
config: Optional[Config] = None
|
||||||
|
type: Optional[Type] = Field(None, description='Predefined camera movements type')
|
||||||
|
|
||||||
|
|
||||||
|
class Duration(str, Enum):
|
||||||
|
field_5 = 5
|
||||||
|
field_10 = 10
|
||||||
|
|
||||||
|
|
||||||
|
class Mode(str, Enum):
|
||||||
|
std = 'std'
|
||||||
|
pro = 'pro'
|
||||||
|
|
||||||
|
|
||||||
|
class TaskInfo(BaseModel):
|
||||||
|
external_task_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Video(BaseModel):
|
||||||
|
duration: Optional[str] = Field(None, description='Total video duration')
|
||||||
|
id: Optional[str] = Field(None, description='Generated video ID')
|
||||||
|
url: Optional[AnyUrl] = Field(None, description='URL for generated video')
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult(BaseModel):
|
||||||
|
videos: Optional[List[Video]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
submitted = 'submitted'
|
||||||
|
processing = 'processing'
|
||||||
|
succeed = 'succeed'
|
||||||
|
failed = 'failed'
|
||||||
|
|
||||||
|
|
||||||
|
class Data(BaseModel):
|
||||||
|
created_at: Optional[int] = Field(None, description='Task creation time')
|
||||||
|
task_id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
task_info: Optional[TaskInfo] = None
|
||||||
|
task_result: Optional[TaskResult] = None
|
||||||
|
task_status: Optional[TaskStatus] = None
|
||||||
|
updated_at: Optional[int] = Field(None, description='Task update time')
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio1(str, Enum):
|
||||||
|
field_16_9 = '16:9'
|
||||||
|
field_9_16 = '9:16'
|
||||||
|
field_1_1 = '1:1'
|
||||||
|
field_4_3 = '4:3'
|
||||||
|
field_3_4 = '3:4'
|
||||||
|
field_3_2 = '3:2'
|
||||||
|
field_2_3 = '2:3'
|
||||||
|
field_21_9 = '21:9'
|
||||||
|
|
||||||
|
|
||||||
|
class ImageReference(str, Enum):
|
||||||
|
subject = 'subject'
|
||||||
|
face = 'face'
|
||||||
|
|
||||||
|
|
||||||
|
class Image(BaseModel):
|
||||||
|
index: Optional[int] = Field(None, description='Image Number (0-9)')
|
||||||
|
url: Optional[AnyUrl] = Field(None, description='URL for generated image')
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult1(BaseModel):
|
||||||
|
images: Optional[List[Image]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Data1(BaseModel):
|
||||||
|
created_at: Optional[int] = Field(None, description='Task creation time')
|
||||||
|
task_id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
task_result: Optional[TaskResult1] = None
|
||||||
|
task_status: Optional[TaskStatus] = None
|
||||||
|
task_status_msg: Optional[str] = Field(None, description='Task status information')
|
||||||
|
updated_at: Optional[int] = Field(None, description='Task update time')
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio2(str, Enum):
|
||||||
|
field_16_9 = '16:9'
|
||||||
|
field_9_16 = '9:16'
|
||||||
|
field_1_1 = '1:1'
|
||||||
|
|
||||||
|
|
||||||
|
class CameraControl1(BaseModel):
|
||||||
|
config: Optional[Config] = None
|
||||||
|
type: Optional[Type] = Field(None, description='Predefined camera movements type')
|
||||||
|
|
||||||
|
|
||||||
|
class ModelName2(str, Enum):
|
||||||
|
kling_v1 = 'kling-v1'
|
||||||
|
kling_v1_6 = 'kling-v1-6'
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult2(BaseModel):
|
||||||
|
videos: Optional[List[Video]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Data2(BaseModel):
|
||||||
|
created_at: Optional[int] = Field(None, description='Task creation time')
|
||||||
|
task_id: Optional[str] = Field(None, description='Task ID')
|
||||||
|
task_info: Optional[TaskInfo] = None
|
||||||
|
task_result: Optional[TaskResult2] = None
|
||||||
|
task_status: Optional[TaskStatus] = None
|
||||||
|
updated_at: Optional[int] = Field(None, description='Task update time')
|
||||||
|
|
||||||
|
|
||||||
|
class Code2(Enum):
|
||||||
|
int_1200 = 1200
|
||||||
|
int_1201 = 1201
|
||||||
|
int_1202 = 1202
|
||||||
|
int_1203 = 1203
|
||||||
|
|
||||||
|
|
||||||
|
class ResourcePackType(str, Enum):
|
||||||
|
decreasing_total = 'decreasing_total'
|
||||||
|
constant_period = 'constant_period'
|
||||||
|
|
||||||
|
|
||||||
|
class Status(str, Enum):
|
||||||
|
toBeOnline = 'toBeOnline'
|
||||||
|
online = 'online'
|
||||||
|
expired = 'expired'
|
||||||
|
runOut = 'runOut'
|
||||||
|
|
||||||
|
|
||||||
|
class ResourcePackSubscribeInfo(BaseModel):
|
||||||
|
effective_time: Optional[int] = Field(
|
||||||
|
None, description='Effective time, Unix timestamp in ms'
|
||||||
|
)
|
||||||
|
invalid_time: Optional[int] = Field(
|
||||||
|
None, description='Expiration time, Unix timestamp in ms'
|
||||||
|
)
|
||||||
|
purchase_time: Optional[int] = Field(
|
||||||
|
None, description='Purchase time, Unix timestamp in ms'
|
||||||
|
)
|
||||||
|
remaining_quantity: Optional[float] = Field(
|
||||||
|
None, description='Remaining quantity (updated with a 12-hour delay)'
|
||||||
|
)
|
||||||
|
resource_pack_id: Optional[str] = Field(None, description='Resource package ID')
|
||||||
|
resource_pack_name: Optional[str] = Field(None, description='Resource package name')
|
||||||
|
resource_pack_type: Optional[ResourcePackType] = Field(
|
||||||
|
None,
|
||||||
|
description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)',
|
||||||
|
)
|
||||||
|
status: Optional[Status] = Field(None, description='Resource Package Status')
|
||||||
|
total_quantity: Optional[float] = Field(None, description='Total quantity')
|
||||||
|
|
||||||
|
class Background(str, Enum):
|
||||||
|
transparent = 'transparent'
|
||||||
|
opaque = 'opaque'
|
||||||
|
|
||||||
|
|
||||||
|
class Moderation(str, Enum):
|
||||||
|
low = 'low'
|
||||||
|
auto = 'auto'
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFormat(str, Enum):
|
||||||
|
png = 'png'
|
||||||
|
webp = 'webp'
|
||||||
|
jpeg = 'jpeg'
|
||||||
|
|
||||||
|
|
||||||
|
class Quality(str, Enum):
|
||||||
|
low = 'low'
|
||||||
|
medium = 'medium'
|
||||||
|
high = 'high'
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageEditRequest(BaseModel):
|
||||||
|
background: Optional[str] = Field(
|
||||||
|
None, description='Background transparency', examples=['opaque']
|
||||||
|
)
|
||||||
|
model: str = Field(
|
||||||
|
..., description='The model to use for image editing', examples=['gpt-image-1']
|
||||||
|
)
|
||||||
|
moderation: Optional[Moderation] = Field(
|
||||||
|
None, description='Content moderation setting', examples=['auto']
|
||||||
|
)
|
||||||
|
n: Optional[int] = Field(
|
||||||
|
None, description='The number of images to generate', examples=[1]
|
||||||
|
)
|
||||||
|
output_compression: Optional[int] = Field(
|
||||||
|
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
|
||||||
|
)
|
||||||
|
output_format: Optional[OutputFormat] = Field(
|
||||||
|
None, description='Format of the output image', examples=['png']
|
||||||
|
)
|
||||||
|
prompt: str = Field(
|
||||||
|
...,
|
||||||
|
description='A text description of the desired edit',
|
||||||
|
examples=['Give the rocketship rainbow coloring'],
|
||||||
|
)
|
||||||
|
quality: Optional[str] = Field(
|
||||||
|
None, description='The quality of the edited image', examples=['low']
|
||||||
|
)
|
||||||
|
size: Optional[str] = Field(
|
||||||
|
None, description='Size of the output image', examples=['1024x1024']
|
||||||
|
)
|
||||||
|
user: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='A unique identifier for end-user monitoring',
|
||||||
|
examples=['user-1234'],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Quality1(str, Enum):
|
||||||
|
low = 'low'
|
||||||
|
medium = 'medium'
|
||||||
|
high = 'high'
|
||||||
|
standard = 'standard'
|
||||||
|
hd = 'hd'
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(str, Enum):
|
||||||
|
url = 'url'
|
||||||
|
b64_json = 'b64_json'
|
||||||
|
|
||||||
|
|
||||||
|
class Style(str, Enum):
|
||||||
|
vivid = 'vivid'
|
||||||
|
natural = 'natural'
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageGenerationRequest(BaseModel):
|
||||||
|
background: Optional[Background] = Field(
|
||||||
|
None, description='Background transparency', examples=['opaque']
|
||||||
|
)
|
||||||
|
model: Optional[str] = Field(
|
||||||
|
None, description='The model to use for image generation', examples=['dall-e-3']
|
||||||
|
)
|
||||||
|
moderation: Optional[Moderation] = Field(
|
||||||
|
None, description='Content moderation setting', examples=['auto']
|
||||||
|
)
|
||||||
|
n: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description='The number of images to generate (1-10). Only 1 supported for dall-e-3.',
|
||||||
|
examples=[1],
|
||||||
|
)
|
||||||
|
output_compression: Optional[int] = Field(
|
||||||
|
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
|
||||||
|
)
|
||||||
|
output_format: Optional[OutputFormat] = Field(
|
||||||
|
None, description='Format of the output image', examples=['png']
|
||||||
|
)
|
||||||
|
prompt: str = Field(
|
||||||
|
...,
|
||||||
|
description='A text description of the desired image',
|
||||||
|
examples=['Draw a rocket in front of a blackhole in deep space'],
|
||||||
|
)
|
||||||
|
quality: Optional[Quality1] = Field(
|
||||||
|
None, description='The quality of the generated image', examples=['high']
|
||||||
|
)
|
||||||
|
response_format: Optional[ResponseFormat] = Field(
|
||||||
|
None, description='Response format of image data', examples=['b64_json']
|
||||||
|
)
|
||||||
|
size: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Size of the image (e.g., 1024x1024, 1536x1024, auto)',
|
||||||
|
examples=['1024x1536'],
|
||||||
|
)
|
||||||
|
style: Optional[Style] = Field(
|
||||||
|
None, description='Style of the image (only for dall-e-3)', examples=['vivid']
|
||||||
|
)
|
||||||
|
user: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='A unique identifier for end-user monitoring',
|
||||||
|
examples=['user-1234'],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Datum1(BaseModel):
|
||||||
|
b64_json: Optional[str] = Field(None, description='Base64 encoded image data')
|
||||||
|
revised_prompt: Optional[str] = Field(None, description='Revised prompt')
|
||||||
|
url: Optional[str] = Field(None, description='URL of the image')
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageGenerationResponse(BaseModel):
|
||||||
|
data: Optional[List[Datum1]] = None
|
||||||
|
class User(BaseModel):
|
||||||
|
email: Optional[str] = Field(None, description='The email address for this user.')
|
||||||
|
id: Optional[str] = Field(None, description='The unique id for this user.')
|
||||||
|
isAdmin: Optional[bool] = Field(
|
||||||
|
None, description='Indicates if the user has admin privileges.'
|
||||||
|
)
|
||||||
|
isApproved: Optional[bool] = Field(
|
||||||
|
None, description='Indicates if the user is approved.'
|
||||||
|
)
|
||||||
|
name: Optional[str] = Field(None, description='The name for this user.')
|
||||||
337
comfy_api_nodes/apis/client.py
Normal file
337
comfy_api_nodes/apis/client.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
"""
|
||||||
|
API Client Framework for api.comfy.org.
|
||||||
|
|
||||||
|
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||||||
|
It supports both synchronous and asynchronous API operations with proper type validation.
|
||||||
|
|
||||||
|
Key Components:
|
||||||
|
--------------
|
||||||
|
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||||||
|
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||||||
|
3. ApiOperation - Executes a single synchronous API operation
|
||||||
|
|
||||||
|
Usage Examples:
|
||||||
|
--------------
|
||||||
|
|
||||||
|
# Example 1: Synchronous API Operation
|
||||||
|
# ------------------------------------
|
||||||
|
# For a simple API call that returns the result immediately:
|
||||||
|
|
||||||
|
# 1. Create the API client
|
||||||
|
api_client = ApiClient(
|
||||||
|
base_url="https://api.example.com",
|
||||||
|
api_key="your_api_key_here",
|
||||||
|
timeout=30.0,
|
||||||
|
verify_ssl=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Define the endpoint
|
||||||
|
user_info_endpoint = ApiEndpoint(
|
||||||
|
path="/v1/users/me",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest, # No request body needed
|
||||||
|
response_model=UserProfile, # Pydantic model for the response
|
||||||
|
query_params=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Create the request object
|
||||||
|
request = EmptyRequest()
|
||||||
|
|
||||||
|
# 4. Create and execute the operation
|
||||||
|
operation = ApiOperation(
|
||||||
|
endpoint=user_info_endpoint,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
user_profile = operation.execute(client=api_client) # Returns immediately with the result
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Type,
|
||||||
|
Optional,
|
||||||
|
Any,
|
||||||
|
TypeVar,
|
||||||
|
Generic,
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from enum import Enum
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
R = TypeVar("R", bound=BaseModel)
|
||||||
|
|
||||||
|
class EmptyRequest(BaseModel):
|
||||||
|
"""Base class for empty request bodies.
|
||||||
|
For GET requests, fields will be sent as query parameters."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HttpMethod(str, Enum):
|
||||||
|
GET = "GET"
|
||||||
|
POST = "POST"
|
||||||
|
PUT = "PUT"
|
||||||
|
DELETE = "DELETE"
|
||||||
|
PATCH = "PATCH"
|
||||||
|
|
||||||
|
|
||||||
|
class ApiClient:
|
||||||
|
"""
|
||||||
|
Client for making HTTP requests to an API with authentication and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
verify_ssl: bool = True,
|
||||||
|
):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout = timeout
|
||||||
|
self.verify_ssl = verify_ssl
|
||||||
|
|
||||||
|
def get_headers(self) -> Dict[str, str]:
|
||||||
|
"""Get headers for API requests, including authentication if available"""
|
||||||
|
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
|
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
    def request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the API

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters
            json: JSON body data
            files: Files to upload
            headers: Additional headers

        Returns:
            Parsed JSON response

        Raises:
            requests.RequestException: If the request fails
        """
        url = urljoin(self.base_url, path)
        self.check_auth_token(self.api_key)
        # Combine default headers with any provided headers
        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)

        # Let requests handle the content type when files are present.
        if files:
            del request_headers["Content-Type"]

        logging.debug(f"[DEBUG] Request Headers: {request_headers}")
        logging.debug(f"[DEBUG] Files: {files}")
        logging.debug(f"[DEBUG] Params: {params}")
        logging.debug(f"[DEBUG] Json: {json}")

        try:
            # If files are present, use data parameter instead of json
            if files:
                form_data = {}
                if json:
                    form_data.update(json)
                response = requests.request(
                    method=method,
                    url=url,
                    params=params,
                    data=form_data,  # Use data instead of json
                    files=files,
                    headers=request_headers,
                    timeout=self.timeout,
                    verify=self.verify_ssl,
                )
            else:
                response = requests.request(
                    method=method,
                    url=url,
                    params=params,
                    json=json,
                    headers=request_headers,
                    timeout=self.timeout,
                    verify=self.verify_ssl,
                )

            # Raise exception for error status codes
            response.raise_for_status()
        except requests.ConnectionError:
            raise Exception(
                f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
            )

        except requests.Timeout:
            raise Exception(
                f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
            )

        except requests.HTTPError as e:
            status_code = e.response.status_code if hasattr(e, "response") else None
            error_message = f"HTTP Error: {str(e)}"

            # Try to extract detailed error message from JSON response
            try:
                if hasattr(e, "response") and e.response.content:
                    error_json = e.response.json()
                    if "error" in error_json and "message" in error_json["error"]:
                        error_message = f"API Error: {error_json['error']['message']}"
                        if "type" in error_json["error"]:
                            error_message += f" (Type: {error_json['error']['type']})"
                    else:
                        error_message = f"API Error: {error_json}"
            except Exception as json_error:
                # If we can't parse the JSON, fall back to the original error message
                logging.debug(f"[DEBUG] Failed to parse error response: {str(json_error)}")

            logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
            if hasattr(e, "response") and e.response.content:
                logging.debug(f"[DEBUG] Response content: {e.response.content}")
            if status_code == 401:
                error_message = "Unauthorized: Please login first to use this node."
            if status_code == 402:
                error_message = "Payment Required: Please add credits to your account to use this node."
            if status_code == 409:
                error_message = "There is a problem with your account. Please contact support@comfy.org."
            if status_code == 429:
                error_message = "Rate Limit Exceeded: Please try again later."
            raise Exception(error_message)

        # Parse and return JSON response
        if response.content:
            return response.json()
        return {}

    def check_auth_token(self, auth_token):
        """Verify that an auth token is present."""
        if auth_token is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token

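For orientation, a minimal usage sketch of this request() method. It assumes the ApiClient constructor arguments shown further down in SynchronousOperation.execute (base_url, api_key, timeout, verify_ssl); the key and endpoint paths below are placeholders, not real api.comfy.org routes.

# Hedged sketch; "your-api-key" and both paths are placeholders.
client = ApiClient(
    base_url="https://api.comfy.org",
    api_key="your-api-key",
    timeout=30.0,
    verify_ssl=True,
)
# GET with query parameters; returns the parsed JSON body as a dict.
status = client.request("GET", "/some/endpoint", params={"verbose": "true"})
# POST with a JSON body; Content-Type stays application/json because no files are passed.
job = client.request("POST", "/some/other/endpoint", json={"prompt": "a red cube"})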
class ApiEndpoint(Generic[T, R]):
    """Defines an API endpoint with its request and response types"""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[Dict[str, Any]] = None,
    ):
        """Initialize an API endpoint definition.

        Args:
            path: The URL path for this endpoint, can include placeholders like {id}
            method: The HTTP method to use (GET, POST, etc.)
            request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
            response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
            query_params: Optional dictionary of query parameters to include in the request
        """
        self.path = path
        self.method = method
        self.request_model = request_model
        self.response_model = response_model
        self.query_params = query_params or {}

class SynchronousOperation(Generic[T, R]):
    """
    Represents a single synchronous API operation.
    """

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[Dict[str, Any]] = None,
        api_base: str = "https://api.comfy.org",
        auth_token: Optional[str] = None,
        timeout: float = 604800.0,
        verify_ssl: bool = True,
    ):
        self.endpoint = endpoint
        self.request = request
        self.response = None
        self.error = None
        self.api_base = api_base
        self.auth_token = auth_token
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.files = files

    def execute(self, client: Optional[ApiClient] = None) -> R:
        """Execute the API operation using the provided client or create one"""
        try:
            # Create client if not provided
            if client is None:
                if self.api_base is None:
                    raise ValueError("Either client or api_base must be provided")
                client = ApiClient(
                    base_url=self.api_base,
                    api_key=self.auth_token,
                    timeout=self.timeout,
                    verify_ssl=self.verify_ssl,
                )

            # Convert request model to dict, but use None for EmptyRequest
            request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True)

            # Debug log for request
            logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}")
            logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
            logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")

            # Make the request
            resp = client.request(
                method=self.endpoint.method.value,
                path=self.endpoint.path,
                json=request_dict,
                params=self.endpoint.query_params,
                files=self.files,
            )

            # Debug log for response
            logging.debug("=" * 50)
            logging.debug("[DEBUG] RESPONSE DETAILS:")
            logging.debug("[DEBUG] Status Code: 200 (Success)")
            logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
            logging.debug("=" * 50)

            # Parse and return the response
            return self._parse_response(resp)

        except Exception as e:
            logging.debug(f"[DEBUG] API Exception: {str(e)}")
            raise Exception(str(e))

    def _parse_response(self, resp):
        """Parse response data - can be overridden by subclasses"""
        # The response is already the complete object, don't extract just the "data" field
        # as that would lose the outer structure (created timestamp, etc.)

        # Parse response using the provided model
        self.response = self.endpoint.response_model.model_validate(resp)
        logging.debug(f"[DEBUG] Parsed Response: {self.response}")
        return self.response
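Putting the two classes together, this is the pattern the API nodes below follow. A hedged sketch, assuming pydantic v2 (as pinned in requirements.txt in this diff); ExampleRequest, ExampleResponse, and the path are hypothetical stand-ins, not real comfy_api_nodes models or routes.

# Hedged sketch: ExampleRequest/ExampleResponse are hypothetical pydantic models.
from pydantic import BaseModel

class ExampleRequest(BaseModel):
    prompt: str

class ExampleResponse(BaseModel):
    id: str
    status: str

endpoint = ApiEndpoint(
    path="/example/generate",              # placeholder path
    method=HttpMethod.POST,
    request_model=ExampleRequest,
    response_model=ExampleResponse,
)
op = SynchronousOperation(
    endpoint=endpoint,
    request=ExampleRequest(prompt="hello"),
    auth_token="token-from-the-frontend",  # normally filled from the hidden AUTH_TOKEN_COMFY_ORG input
)
result = op.execute()                      # returns an ExampleResponse instance on success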
434 comfy_api_nodes/nodes_api.py Normal file
@@ -0,0 +1,434 @@
import io
from inspect import cleandoc

from comfy.utils import common_upscale
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy_api_nodes.apis import (
    OpenAIImageGenerationRequest,
    OpenAIImageEditRequest,
    OpenAIImageGenerationResponse
)
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation

import numpy as np
from PIL import Image
import requests
import torch
import math
import base64

def downscale_input(image):
    samples = image.movedim(-1, 1)
    # downscaling input images to roughly the same size as the outputs
    total = int(1536 * 1024)
    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
    if scale_by >= 1:
        return image
    width = round(samples.shape[3] * scale_by)
    height = round(samples.shape[2] * scale_by)

    s = common_upscale(samples, width, height, "lanczos", "disabled")
    s = s.movedim(1, -1)
    return s

def validate_and_cast_response(response):
    # validate raw JSON response
    data = response.data
    if not data or len(data) == 0:
        raise Exception("No images returned from API endpoint")

    # Get base64 image data
    image_url = data[0].url
    b64_data = data[0].b64_json
    if not image_url and not b64_data:
        raise Exception("No image was generated in the response")

    if b64_data:
        img_data = base64.b64decode(b64_data)
        img = Image.open(io.BytesIO(img_data))

    elif image_url:
        img_response = requests.get(image_url)
        if img_response.status_code != 200:
            raise Exception("Failed to download the image")
        img = Image.open(io.BytesIO(img_response.content))

    img = img.convert("RGBA")

    # Convert to numpy array, normalize to float32 between 0 and 1
    img_array = np.array(img).astype(np.float32) / 255.0

    # Convert to torch tensor and add batch dimension
    return torch.from_numpy(img_array)[None,]
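downscale_input works on ComfyUI's IMAGE layout (batch, height, width, channels) with values in [0, 1] and caps the pixel count at 1536*1024. A hedged sanity check, assuming comfy.utils.common_upscale is importable as in this module:

import torch

big = torch.rand(1, 2048, 2048, 3)
small = torch.rand(1, 512, 512, 3)
# 2048x2048 exceeds the budget: sqrt(1536*1024 / 2048**2) is roughly 0.612, so both sides shrink to about 1254.
print(downscale_input(big).shape)    # approximately torch.Size([1, 1254, 1254, 3])
# 512x512 is under the budget (scale_by >= 1), so it is returned unchanged.
print(downscale_input(small).shape)  # torch.Size([1, 512, 512, 3])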
class OpenAIDalle2(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's DALL·E 2 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for DALL·E",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31-1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "size": (IO.COMBO, {
                    "options": ["256x256", "512x512", "1024x1024"],
                    "default": "1024x1024",
                    "tooltip": "Image size",
                }),
                "n": (IO.INT, {
                    "default": 1,
                    "min": 1,
                    "max": 8,
                    "step": 1,
                    "display": "number",
                    "tooltip": "How many images to generate",
                }),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image for image editing.",
                }),
                "mask": (IO.MASK, {
                    "default": None,
                    "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None):
        model = "dall-e-2"
        path = "/proxy/openai/images/generations"
        request_class = OpenAIImageGenerationRequest
        img_binary = None

        if image is not None and mask is not None:
            path = "/proxy/openai/images/edits"
            request_class = OpenAIImageEditRequest

            input_tensor = image.squeeze().cpu()
            height, width, channels = input_tensor.shape
            rgba_tensor = torch.ones(height, width, 4, device="cpu")
            rgba_tensor[:, :, :channels] = input_tensor

            if mask.shape[1:] != image.shape[1:-1]:
                raise Exception("Mask and Image must be the same size")
            rgba_tensor[:, :, 3] = (1 - mask.squeeze().cpu())

            rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze()

            image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
            img = Image.fromarray(image_np)
            img_byte_arr = io.BytesIO()
            img.save(img_byte_arr, format='PNG')
            img_byte_arr.seek(0)
            img_binary = img_byte_arr  # kept as a file-like object rather than calling .getvalue()
            img_binary.name = "image.png"
        elif image is not None or mask is not None:
            raise Exception("Dall-E 2 image editing requires an image AND a mask")

        # Build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=request_class,
                response_model=OpenAIImageGenerationResponse
            ),
            request=request_class(
                model=model,
                prompt=prompt,
                n=n,
                size=size,
                seed=seed,
            ),
            files={
                "image": img_binary,
            } if img_binary else None,
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)

class OpenAIDalle3(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's DALL·E 3 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for DALL·E",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31-1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "quality": (IO.COMBO, {
                    "options": ["standard", "hd"],
                    "default": "standard",
                    "tooltip": "Image quality",
                }),
                "style": (IO.COMBO, {
                    "options": ["natural", "vivid"],
                    "default": "natural",
                    "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
                }),
                "size": (IO.COMBO, {
                    "options": ["1024x1024", "1024x1792", "1792x1024"],
                    "default": "1024x1024",
                    "tooltip": "Image size",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None):
        model = "dall-e-3"

        # build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/openai/images/generations",
                method=HttpMethod.POST,
                request_model=OpenAIImageGenerationRequest,
                response_model=OpenAIImageGenerationResponse
            ),
            request=OpenAIImageGenerationRequest(
                model=model,
                prompt=prompt,
                quality=quality,
                size=size,
                style=style,
                seed=seed,
            ),
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)

class OpenAIGPTImage1(ComfyNodeABC):
    """
    Generates images synchronously via OpenAI's GPT Image 1 endpoint.

    Uses the proxy at /proxy/openai/images/generations. Returned URLs are short-lived,
    so download or cache results if you need to keep them.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "prompt": (IO.STRING, {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Text prompt for GPT Image 1",
                }),
            },
            "optional": {
                "seed": (IO.INT, {
                    "default": 0,
                    "min": 0,
                    "max": 2**31-1,
                    "step": 1,
                    "display": "number",
                    "tooltip": "not implemented yet in backend",
                }),
                "quality": (IO.COMBO, {
                    "options": ["low", "medium", "high"],
                    "default": "low",
                    "tooltip": "Image quality, affects cost and generation time.",
                }),
                "background": (IO.COMBO, {
                    "options": ["opaque", "transparent"],
                    "default": "opaque",
                    "tooltip": "Return image with or without background",
                }),
                "size": (IO.COMBO, {
                    "options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
                    "default": "auto",
                    "tooltip": "Image size",
                }),
                "n": (IO.INT, {
                    "default": 1,
                    "min": 1,
                    "max": 8,
                    "step": 1,
                    "display": "number",
                    "tooltip": "How many images to generate",
                }),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image for image editing.",
                }),
                "mask": (IO.MASK, {
                    "default": None,
                    "tooltip": "Optional mask for inpainting (white areas will be replaced)",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG"
            }
        }

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

    def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
        model = "gpt-image-1"
        path = "/proxy/openai/images/generations"
        request_class = OpenAIImageGenerationRequest
        img_binaries = []
        mask_binary = None
        files = []

        if image is not None:
            path = "/proxy/openai/images/edits"
            request_class = OpenAIImageEditRequest

            batch_size = image.shape[0]

            for i in range(batch_size):
                single_image = image[i:i+1]
                scaled_image = downscale_input(single_image).squeeze()

                image_np = (scaled_image.numpy() * 255).astype(np.uint8)
                img = Image.fromarray(image_np)
                img_byte_arr = io.BytesIO()
                img.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                img_binary = img_byte_arr
                img_binary.name = f"image_{i}.png"

                img_binaries.append(img_binary)
                if batch_size == 1:
                    files.append(("image", img_binary))
                else:
                    files.append(("image[]", img_binary))

        if mask is not None:
            if image.shape[0] != 1:
                raise Exception("Cannot use a mask with multiple image")
            if image is None:
                raise Exception("Cannot use a mask without an input image")
            if mask.shape[1:] != image.shape[1:-1]:
                raise Exception("Mask and Image must be the same size")
            batch, height, width = mask.shape
            rgba_mask = torch.zeros(height, width, 4, device="cpu")
            rgba_mask[:, :, 3] = (1 - mask.squeeze().cpu())

            scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze()

            mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
            mask_img = Image.fromarray(mask_np)
            mask_img_byte_arr = io.BytesIO()
            mask_img.save(mask_img_byte_arr, format='PNG')
            mask_img_byte_arr.seek(0)
            mask_binary = mask_img_byte_arr
            mask_binary.name = "mask.png"
            files.append(("mask", mask_binary))

        # Build the operation
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=path,
                method=HttpMethod.POST,
                request_model=request_class,
                response_model=OpenAIImageGenerationResponse
            ),
            request=request_class(
                model=model,
                prompt=prompt,
                quality=quality,
                background=background,
                n=n,
                seed=seed,
                size=size,
            ),
            files=files if files else None,
            auth_token=auth_token
        )

        response = operation.execute()

        img_tensor = validate_and_cast_response(response)
        return (img_tensor,)


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "OpenAIDalle2": OpenAIDalle2,
    "OpenAIDalle3": OpenAIDalle3,
    "OpenAIGPTImage1": OpenAIGPTImage1,
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "OpenAIDalle2": "OpenAI DALL·E 2",
    "OpenAIDalle3": "OpenAI DALL·E 3",
    "OpenAIGPTImage1": "OpenAI GPT Image 1",
}
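All three nodes declare the hidden input "auth_token": "AUTH_TOKEN_COMFY_ORG". A hedged, standalone sketch of how that value reaches api_call, mirroring the execution.py hunk later in this diff (the token string is a placeholder):

extra_data = {"auth_token_comfy_org": "token-from-the-frontend"}   # supplied with the prompt request
hidden = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
input_data_all = {}
for name, kind in hidden.items():
    if kind == "AUTH_TOKEN_COMFY_ORG":
        input_data_all[name] = [extra_data.get("auth_token_comfy_org", None)]
print(input_data_all)   # {'auth_token': ['token-from-the-frontend']}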
@@ -1,6 +1,9 @@
import nodes
from __future__ import annotations

from typing import Type, Literal

import nodes

from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions

class DependencyCycleError(Exception):
    pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
    def get_original_prompt(self):
        return self.original_prompt

def get_input_info(class_def, input_name, valid_inputs=None):
def get_input_info(
    class_def: Type[ComfyNodeABC],
    input_name: str,
    valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
    """Get the input type, category, and extra info for a given input name.

    Arguments:
        class_def: The class definition of the node.
        input_name: The name of the input to get info for.
        valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().

    Returns:
        tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
    """

    valid_inputs = valid_inputs or class_def.INPUT_TYPES()
    input_info = None
    input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
            from_node_id, from_socket = value
            if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                continue
            input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
            _, _, input_info = self.get_input_info(unique_id, input_name)
            is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
            if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
                node_ids.append(from_node_id)
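To make the new return annotation concrete, a hedged sketch of calling get_input_info with a hypothetical node class (ExampleNode is not part of the repository):

class ExampleNode(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {"value": ("INT", {"default": 0, "lazy": True})},
            "optional": {"label": ("STRING", {"default": ""})},
        }

input_type, input_category, input_info = get_input_info(ExampleNode, "value")
# input_type == "INT", input_category == "required", input_info == {"default": 0, "lazy": True}
missing = get_input_info(ExampleNode, "does_not_exist")
# missing == (None, None, None), which is why callers such as TopologicalSort guard with `input_info is not None`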
100 comfy_extras/nodes_fresca.py Normal file
@@ -0,0 +1,100 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft


def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
    """
    Apply frequency-dependent scaling to an image tensor using Fourier transforms.

    Parameters:
        x: Input tensor of shape (B, C, H, W)
        scale_low: Scaling factor for low-frequency components (default: 1.0)
        scale_high: Scaling factor for high-frequency components (default: 1.5)
        freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)

    Returns:
        x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
    """
    # Preserve input dtype and device
    dtype, device = x.dtype, x.device

    # Convert to float32 for FFT computations
    x = x.to(torch.float32)

    # 1) Apply FFT and shift low frequencies to center
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    # Initialize mask with high-frequency scaling factor
    mask = torch.ones(x_freq.shape, device=device) * scale_high
    m = mask
    for d in range(len(x_freq.shape) - 2):
        dim = d + 2
        cc = x_freq.shape[dim] // 2
        f_c = min(freq_cutoff, cc)
        m = m.narrow(dim, cc - f_c, f_c * 2)

    # Apply low-frequency scaling factor to center region
    m[:] = scale_low

    # 3) Apply frequency-specific scaling
    x_freq = x_freq * mask

    # 4) Convert back to spatial domain
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    # 5) Restore original dtype
    x_filtered = x_filtered.to(dtype)

    return x_filtered

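A quick standalone check of Fourier_filter, using only the function above and torch; identity settings leave the tensor numerically unchanged, while scale_high greater than 1 amplifies everything outside the central low-frequency window:

import torch

x = torch.randn(1, 4, 64, 64)                        # (B, C, H, W), e.g. a guidance tensor
identity = Fourier_filter(x, scale_low=1.0, scale_high=1.0, freq_cutoff=20)
boosted = Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20)
print(torch.allclose(identity, x, atol=1e-4))        # True, up to FFT round-off
print(bool((boosted - x).abs().mean() > 0))          # True, high-frequency content was amplified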
class FreSca:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
                                        "tooltip": "Scaling factor for low-frequency components"}),
                "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
                                         "tooltip": "Scaling factor for high-frequency components"}),
                "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
                                        "tooltip": "Number of frequency indices around center to consider as low-frequency"}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
    def patch(self, model, scale_low, scale_high, freq_cutoff):
        def custom_cfg_function(args):
            cond = args["conds_out"][0]
            uncond = args["conds_out"][1]

            guidance = cond - uncond
            filtered_guidance = Fourier_filter(
                guidance,
                scale_low=scale_low,
                scale_high=scale_high,
                freq_cutoff=freq_cutoff,
            )
            filtered_cond = filtered_guidance + uncond

            return [filtered_cond, uncond]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(custom_cfg_function)

        return (m,)


NODE_CLASS_MAPPINGS = {
    "FreSca": FreSca,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FreSca": "FreSca",
}
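The pre-CFG hook only replaces the conditional output, so the effect is equivalent to filtering the guidance term before the usual classifier-free-guidance mix. A hedged, standalone arithmetic check (w and the tensors are arbitrary stand-ins for the sampler's cfg scale and model outputs):

import torch

cond, uncond, w = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64), 7.0
filtered_guidance = Fourier_filter(cond - uncond, scale_low=1.0, scale_high=1.25, freq_cutoff=20)
filtered_cond = filtered_guidance + uncond           # what custom_cfg_function returns in slot 0
cfg_out = uncond + w * (filtered_cond - uncond)      # the usual CFG combination applied afterwards
print(torch.allclose(cfg_out, uncond + w * filtered_guidance))   # True: the guidance itself was rescaled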
@@ -26,7 +26,30 @@ class QuadrupleCLIPLoader:
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (clip,)

class CLIPTextEncodeHiDream:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning"

    def encode(self, clip, clip_l, clip_g, t5xxl, llama):
        tokens = clip.tokenize(clip_g)
        tokens["l"] = clip.tokenize(clip_l)["l"]
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
        tokens["llama"] = clip.tokenize(llama)["llama"]
        return (clip.encode_from_tokens_scheduled(tokens), )

NODE_CLASS_MAPPINGS = {
    "QuadrupleCLIPLoader": QuadrupleCLIPLoader,
    "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -41,7 +41,7 @@ class Load3D():
|
|||||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file, normal_image, lineart_image
|
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
||||||
|
|
||||||
class Load3DAnimation():
|
class Load3DAnimation():
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -59,8 +59,8 @@ class Load3DAnimation():
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -77,13 +77,16 @@ class Load3DAnimation():
|
|||||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file, normal_image
|
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
||||||
|
|
||||||
class Preview3D():
|
class Preview3D():
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"camera_info": ("LOAD3D_CAMERA", {})
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@@ -95,13 +98,22 @@ class Preview3D():
|
|||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def process(self, model_file, **kwargs):
|
def process(self, model_file, **kwargs):
|
||||||
return {"ui": {"model_file": [model_file]}, "result": ()}
|
camera_info = kwargs.get("camera_info", None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ui": {
|
||||||
|
"result": [model_file, camera_info]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class Preview3DAnimation():
|
class Preview3DAnimation():
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"camera_info": ("LOAD3D_CAMERA", {})
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@@ -113,7 +125,13 @@ class Preview3DAnimation():
|
|||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def process(self, model_file, **kwargs):
|
def process(self, model_file, **kwargs):
|
||||||
return {"ui": {"model_file": [model_file]}, "result": ()}
|
camera_info = kwargs.get("camera_info", None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ui": {
|
||||||
|
"result": [model_file, camera_info]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Load3D": Load3D,
|
"Load3D": Load3D,
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
import folder_paths
import random

import nodes
from nodes import MAX_RESOLUTION

def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
@@ -362,6 +365,30 @@ class ThresholdMask:
        mask = (mask > value).float()
        return (mask,)

# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"mask": ("MASK",), },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    FUNCTION = "execute"
    CATEGORY = "mask"

    def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)


NODE_CLASS_MAPPINGS = {
    "LatentCompositeMasked": LatentCompositeMasked,
@@ -376,6 +403,7 @@ NODE_CLASS_MAPPINGS = {
    "FeatherMask": FeatherMask,
    "GrowMask": GrowMask,
    "ThresholdMask": ThresholdMask,
    "MaskPreview": MaskPreview
}

NODE_DISPLAY_NAME_MAPPINGS = {
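The reshape in MaskPreview.execute turns a (batch, height, width) mask into a greyscale IMAGE batch that SaveImage can write. A hedged, standalone shape check using only torch:

import torch

mask = torch.rand(2, 64, 64)
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
print(preview.shape)   # torch.Size([2, 64, 64, 3]): the single mask channel is repeated into RGB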
@@ -1,6 +1,8 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations

import sys

from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO


@@ -23,7 +25,7 @@ class Int(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {"value": (IO.INT, {"control_after_generate": True})},
            "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
        }

    RETURN_TYPES = (IO.INT,)
@@ -38,7 +40,7 @@ class Float(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {"value": (IO.FLOAT, {})},
            "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
        }

    RETURN_TYPES = (IO.FLOAT,)
@@ -50,13 +50,15 @@ class SaveWEBM:
        for x in extra_pnginfo:
            container.metadata[x] = json.dumps(extra_pnginfo[x])

        codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
        codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
        stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
        stream.width = images.shape[-2]
        stream.height = images.shape[-3]
        stream.pix_fmt = "yuv420p"
        stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
        stream.bit_rate = 0
        stream.options = {'crf': str(crf)}
        if codec == "av1":
            stream.options["preset"] = "6"

        for frame in images:
            frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
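The AV1 path now uses the SVT-AV1 encoder with 10-bit output and preset 6. A hedged, standalone PyAV sketch of the same settings, assuming av>=14.1.0 (as pinned in requirements.txt in this diff) and an FFmpeg build with libsvtav1; "out.webm" and the frame contents are placeholders:

import av
import numpy as np
from fractions import Fraction

container = av.open("out.webm", mode="w")
stream = container.add_stream("libsvtav1", rate=Fraction(24000, 1000))
stream.width, stream.height = 640, 480
stream.pix_fmt = "yuv420p10le"                 # 10-bit, as selected for av1 above
stream.options = {"crf": "32", "preset": "6"}
for _ in range(24):                            # one second of black frames
    frame = av.VideoFrame.from_ndarray(np.zeros((480, 640, 3), dtype=np.uint8), format="rgb24")
    for packet in stream.encode(frame):
        container.mux(packet)
for packet in stream.encode():                 # flush the encoder
    container.mux(packet)
container.close()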
@@ -193,9 +193,116 @@ class WanFunInpaintToVideo:
        return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)


class WanVaceToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                             },
                "optional": {"control_video": ("IMAGE", ),
                             "control_masks": ("MASK", ),
                             "reference_image": ("IMAGE", ),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
    RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    EXPERIMENTAL = True

    def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
        latent_length = ((length - 1) // 4) + 1
        if control_video is not None:
            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            if control_video.shape[0] < length:
                control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
        else:
            control_video = torch.ones((length, height, width, 3)) * 0.5

        if reference_image is not None:
            reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            reference_image = vae.encode(reference_image[:, :, :, :3])
            reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)

        if control_masks is None:
            mask = torch.ones((length, height, width, 1))
        else:
            mask = control_masks
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
            if mask.shape[0] < length:
                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)

        control_video = control_video - 0.5
        inactive = (control_video * (1 - mask)) + 0.5
        reactive = (control_video * mask) + 0.5

        inactive = vae.encode(inactive[:, :, :, :3])
        reactive = vae.encode(reactive[:, :, :, :3])
        control_video_latent = torch.cat((inactive, reactive), dim=1)
        if reference_image is not None:
            control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)

        vae_stride = 8
        height_mask = height // vae_stride
        width_mask = width // vae_stride
        mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
        mask = mask.permute(2, 4, 0, 1, 3)
        mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
        mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)

        trim_latent = 0
        if reference_image is not None:
            mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
            mask = torch.cat((mask_pad, mask), dim=1)
            latent_length += reference_image.shape[2]
            trim_latent = reference_image.shape[2]

        mask = mask.unsqueeze(0)
        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})

        latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        out_latent = {}
        out_latent["samples"] = latent
        return (positive, negative, out_latent, trim_latent)

class TrimVideoLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
                              }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/video"

    EXPERIMENTAL = True

    def op(self, samples, trim_amount):
        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = s1[:, :, trim_amount:]
        return (samples_out,)


NODE_CLASS_MAPPINGS = {
    "WanImageToVideo": WanImageToVideo,
    "WanFunControlToVideo": WanFunControlToVideo,
    "WanFunInpaintToVideo": WanFunInpaintToVideo,
    "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
    "WanVaceToVideo": WanVaceToVideo,
    "TrimVideoLatent": TrimVideoLatent,
}
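The densest part of WanVaceToVideo.encode is packing the pixel-space mask into latent resolution: each 8x8 block of pixels becomes 64 channels, and the time axis is resampled to the latent length. A hedged, standalone shape walk-through with the node's default dimensions (pure tensor ops, no ComfyUI needed):

import torch

length, height, width, vae_stride = 81, 480, 832, 8
latent_length = ((length - 1) // 4) + 1                                    # 21
mask = torch.ones(length, height, width, 1)

height_mask, width_mask = height // vae_stride, width // vae_stride       # 60, 104
m = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
m = m.permute(2, 4, 0, 1, 3)                                               # (8, 8, 81, 60, 104)
m = m.reshape(vae_stride * vae_stride, length, height_mask, width_mask)   # (64, 81, 60, 104)
m = torch.nn.functional.interpolate(m.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
print(m.shape)                                                             # torch.Size([64, 21, 60, 104])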
33 execution.py
@@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
    missing_keys = {}
    for x in inputs:
        input_data = inputs[x]
        input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
        def mark_missing():
            missing_keys[x] = True
            input_data_all[x] = (None,)
@@ -144,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
                input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
            if h[x] == "AUTH_TOKEN_COMFY_ORG":
                input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
    return input_data_all, missing_keys

map_node_over_list = None #Don't hook this please
@@ -574,7 +576,7 @@ def validate_inputs(prompt, item, validated):
    received_types = {}

    for x in valid_inputs:
        type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
        input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
        assert extra_info is not None
        if x not in inputs:
            if input_category == "required":
@@ -590,7 +592,7 @@ def validate_inputs(prompt, item, validated):
            continue

        val = inputs[x]
        info = (type_input, extra_info)
        info = (input_type, extra_info)
        if isinstance(val, list):
            if len(val) != 2:
                error = {
@@ -611,8 +613,8 @@ def validate_inputs(prompt, item, validated):
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            received_type = r[val[1]]
            received_types[x] = received_type
            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
            if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
                details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
                details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
@@ -660,22 +662,22 @@ def validate_inputs(prompt, item, validated):
                    val = val["__value__"]
                    inputs[x] = val

                if type_input == "INT":
                if input_type == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                if input_type == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                if input_type == "STRING":
                    val = str(val)
                    inputs[x] = val
                if type_input == "BOOLEAN":
                if input_type == "BOOLEAN":
                    val = bool(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "message": f"Failed to convert an input value to a {input_type} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
@@ -715,18 +717,19 @@ def validate_inputs(prompt, item, validated):
                errors.append(error)
                continue

        if isinstance(type_input, list):
        if isinstance(input_type, list):
            if val not in type_input:
            combo_options = input_type
            if val not in combo_options:
                input_config = info
                list_info = ""

                # Don't send back gigantic lists like if they're lots of
                # scanned model filepaths
                if len(type_input) > 20:
                if len(combo_options) > 20:
                    list_info = f"(list of length {len(type_input)})"
                    list_info = f"(list of length {len(combo_options)})"
                    input_config = None
                else:
                    list_info = str(type_input)
                    list_info = str(combo_options)

                error = {
                    "type": "value_not_in_list",
nodes.py
51
nodes.py
@@ -917,7 +917,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@@ -927,29 +927,10 @@ class CLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
|
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||||
if type == "stable_cascade":
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
|
||||||
elif type == "sd3":
|
|
||||||
clip_type = comfy.sd.CLIPType.SD3
|
|
||||||
elif type == "stable_audio":
|
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
|
||||||
elif type == "mochi":
|
|
||||||
clip_type = comfy.sd.CLIPType.MOCHI
|
|
||||||
elif type == "ltxv":
|
|
||||||
clip_type = comfy.sd.CLIPType.LTXV
|
|
||||||
elif type == "pixart":
|
|
||||||
clip_type = comfy.sd.CLIPType.PIXART
|
|
||||||
elif type == "cosmos":
|
|
||||||
clip_type = comfy.sd.CLIPType.COSMOS
|
|
||||||
elif type == "lumina2":
|
|
||||||
clip_type = comfy.sd.CLIPType.LUMINA2
|
|
||||||
elif type == "wan":
|
|
||||||
clip_type = comfy.sd.CLIPType.WAN
|
|
||||||
else:
|
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
@@ -964,7 +945,7 @@ class DualCLIPLoader:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
|
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@@ -974,19 +955,13 @@ class DualCLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
|
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||||
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
|
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
if type == "sdxl":
|
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
|
||||||
elif type == "sd3":
|
|
||||||
clip_type = comfy.sd.CLIPType.SD3
|
|
||||||
elif type == "flux":
|
|
||||||
clip_type = comfy.sd.CLIPType.FLUX
|
|
||||||
elif type == "hunyuan_video":
|
|
||||||
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
|
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
@@ -2281,7 +2256,13 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_primitive.py",
|
"nodes_primitive.py",
|
||||||
"nodes_cfg.py",
|
"nodes_cfg.py",
|
||||||
"nodes_optimalsteps.py",
|
"nodes_optimalsteps.py",
|
||||||
"nodes_hidream.py"
|
"nodes_hidream.py",
|
||||||
|
"nodes_fresca.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||||
|
api_nodes_files = [
|
||||||
|
"nodes_api.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
@@ -2289,6 +2270,10 @@ def init_builtin_extra_nodes():
|
|||||||
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||||
import_failed.append(node_file)
|
import_failed.append(node_file)
|
||||||
|
|
||||||
|
for node_file in api_nodes_files:
|
||||||
|
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||||
|
import_failed.append(node_file)
|
||||||
|
|
||||||
return import_failed
|
return import_failed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
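Both loaders replace their if/elif chains with a getattr lookup on the CLIPType enum, with STABLE_DIFFUSION as the fallback for unknown names. A hedged, standalone illustration of the pattern (CLIPTypeDemo is a stand-in for comfy.sd.CLIPType):

from enum import Enum

class CLIPTypeDemo(Enum):
    STABLE_DIFFUSION = 1
    SD3 = 2
    HIDREAM = 3

def resolve(type_name: str) -> CLIPTypeDemo:
    # Enum members are class attributes, so getattr works as a name-to-member lookup.
    return getattr(CLIPTypeDemo, type_name.upper(), CLIPTypeDemo.STABLE_DIFFUSION)

print(resolve("sd3"))        # CLIPTypeDemo.SD3
print(resolve("hidream"))    # CLIPTypeDemo.HIDREAM
print(resolve("unknown"))    # falls back to CLIPTypeDemo.STABLE_DIFFUSION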
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.16.8
comfyui-frontend-package==1.17.10
comfyui-workflow-templates==0.1.1
comfyui-workflow-templates==0.1.3
torch
torchsde
torchvision
@@ -22,4 +22,5 @@ psutil
kornia>=0.7.1
spandrel
soundfile
av
av>=14.1.0
pydantic~=2.0
@@ -580,6 +580,9 @@ class PromptServer():
            info['deprecated'] = True
        if getattr(obj_class, "EXPERIMENTAL", False):
            info['experimental'] = True

        if hasattr(obj_class, 'API_NODE'):
            info['api_node'] = obj_class.API_NODE
        return info

    @routes.get("/object_info")