Compare commits
1 commit
v0.3.30...not_requir

| Author | SHA1 | Date |
|---|---|---|
|  | 03df573995 |  |
.github/workflows/stable-release.yml (vendored, 6 lines changed)

@@ -36,7 +36,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.git_tag }}
-          fetch-depth: 150
+          fetch-depth: 0
           persist-credentials: false
       - uses: actions/cache/restore@v4
         id: cache
@@ -70,7 +70,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable
           mv python_embeded ComfyUI_windows_portable
@@ -85,7 +85,7 @@ jobs:

           cd ..

-          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
           mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

           cd ComfyUI_windows_portable
.github/workflows/update-api-stubs.yml (vendored, 47 lines changed)

@@ -1,47 +0,0 @@
-name: Generate Pydantic Stubs from api.comfy.org
-
-on:
-  schedule:
-    - cron: '0 0 * * 1'
-  workflow_dispatch:
-
-jobs:
-  generate-models:
-    runs-on: ubuntu-latest
-
-    steps:
-    - name: Checkout repository
-      uses: actions/checkout@v4
-
-    - name: Set up Python
-      uses: actions/setup-python@v4
-      with:
-        python-version: '3.10'
-
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install 'datamodel-code-generator[http]'
-
-    - name: Generate API models
-      run: |
-        datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
-
-    - name: Check for changes
-      id: git-check
-      run: |
-        git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
-
-    - name: Create Pull Request
-      if: steps.git-check.outputs.changes == 'true'
-      uses: peter-evans/create-pull-request@v5
-      with:
-        commit-message: 'chore: update API models from OpenAPI spec'
-        title: 'Update API models from api.comfy.org'
-        body: |
-          This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
-
-          Generated automatically by the a Github workflow.
-        branch: update-api-stubs
-        delete-branch: true
-        base: main
@@ -56,7 +56,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable_nightly_pytorch
           mv python_embeded ComfyUI_windows_portable_nightly_pytorch
@@ -50,7 +50,7 @@ jobs:

       - uses: actions/checkout@v4
         with:
-          fetch-depth: 150
+          fetch-depth: 0
           persist-credentials: false
       - shell: bash
         run: |
@@ -67,7 +67,7 @@ jobs:
           cd ..

           git clone --depth 1 https://github.com/comfyanonymous/taesd
-          cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
+          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

           mkdir ComfyUI_windows_portable
           mv python_embeded ComfyUI_windows_portable
@@ -82,7 +82,7 @@ jobs:

           cd ..

-          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
           mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z

           cd ComfyUI_windows_portable
CODEOWNERS (26 lines changed)

@@ -5,20 +5,20 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
+/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink

 # Python web server
-/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
+/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
+/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
@@ -62,7 +62,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
-- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - Video Models
    - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
    - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
@@ -216,9 +215,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony

 Nvidia users should install stable pytorch using this command:

-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```

-This is the command to install pytorch nightly instead which might have performance improvements.
+This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.

 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```

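Note on the pip commands above: whichever wheel index is used, the CUDA build that actually got installed can be verified from Python. A minimal sketch (the exact version strings are illustrative, not taken from this diff):

```python
# Quick post-install check for the torch wheels referenced in the README hunk:
# torch.version.cuda reports the CUDA toolkit the wheel was built against.
import torch

print(torch.__version__)          # e.g. "2.7.0+cu128" or "...+cu126"
print(torch.version.cuda)         # e.g. "12.8" / "12.6"; None for CPU-only builds
print(torch.cuda.is_available())  # True only if a compatible NVIDIA GPU and driver are present
```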
@@ -184,27 +184,6 @@ comfyui-frontend-package is not installed.
             )
             sys.exit(-1)

-    @classmethod
-    def templates_path(cls) -> str:
-        try:
-            import comfyui_workflow_templates
-
-            return str(
-                importlib.resources.files(comfyui_workflow_templates) / "templates"
-            )
-        except ImportError:
-            logging.error(
-                f"""
-********** ERROR ***********
-
-comfyui-workflow-templates is not installed.
-
-{frontend_install_warning_message()}
-
-********** ERROR ***********
-""".strip()
-            )
-
     @classmethod
     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
         """
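The removed `templates_path` helper resolves a directory shipped inside an installed package via `importlib.resources.files`. A minimal, self-contained sketch of that lookup pattern, using the stdlib `json` package as a stand-in since `comfyui_workflow_templates` may not be installed:

```python
# importlib.resources.files() returns a Traversable: it supports "/" joining and
# file checks, which is how the removed templates_path() built its path.
import importlib.resources
import json  # stand-in package; the real code imports comfyui_workflow_templates

package_root = importlib.resources.files(json)
print(str(package_root))                      # filesystem location of the installed package
print((package_root / "tool.py").is_file())   # True
```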
@@ -66,7 +66,6 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
 fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
-fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@@ -18,7 +18,6 @@ class Output:
         setattr(self, key, item)

 def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
-    image = image[:, :, :, :3] if image.shape[3] > 3 else image
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
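The only change in the `clip_preprocess` hunk is the removed RGBA guard. A minimal sketch of what that single line does, using ComfyUI's B×H×W×C image layout:

```python
# The removed guard drops a trailing alpha channel before normalization:
# [B, H, W, 4] becomes [B, H, W, 3]; plain RGB input passes through unchanged.
import torch

image = torch.rand(1, 224, 224, 4)  # RGBA batch
image = image[:, :, :, :3] if image.shape[3] > 3 else image
print(image.shape)                  # torch.Size([1, 224, 224, 3])
```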
@@ -1,7 +1,7 @@
 """Comfy-specific type hinting"""

 from __future__ import annotations
-from typing import Literal, TypedDict, Optional
+from typing import Literal, TypedDict
 from typing_extensions import NotRequired
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -115,11 +115,6 @@ class InputTypeOptions(TypedDict):
     """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
     tooltip: NotRequired[str]
     """Tooltip for the input (or widget), shown on pointer hover"""
-    socketless: NotRequired[bool]
-    """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
-    Available from frontend v1.17.5
-    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
-    """
     # class InputTypeNumber(InputTypeOptions):
     # default: float | int
     min: NotRequired[float]
@@ -229,8 +224,6 @@ class ComfyNodeABC(ABC):
     """Flags a node as experimental, informing users that it may change or not work as expected."""
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
-    API_NODE: Optional[bool]
-    """Flags a node as an API node."""

     @classmethod
     @abstractmethod
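`socketless` and `API_NODE` are the fields this hunk removes; `tooltip` is present on both sides. A minimal sketch of where `InputTypeOptions` keys like these end up in practice (the node below is hypothetical, for illustration only):

```python
# Hypothetical custom node: the options dict attached to each input is what
# InputTypeOptions describes (tooltip shown; socketless would sit alongside it).
class ExampleScaleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "strength": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 10.0,
                    "tooltip": "Multiplier applied to the input latent",  # tooltip: NotRequired[str]
                }),
            }
        }
```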
@@ -736,7 +736,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     return control

 def load_controlnet(ckpt_path, model=None, model_options={}):
-    model_options = model_options.copy()
     if "global_average_pooling" not in model_options:
         filename = os.path.splitext(ckpt_path)[0]
         if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
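The single removed line in the `load_controlnet` hunk (`model_options = model_options.copy()`) protects the caller from mutation through the shared `model_options={}` default. A minimal sketch of the failure mode that copy avoids; the function names here are illustrative, not the ComfyUI API:

```python
# Without .copy(), values written into a mutable default dict leak into later calls.
def load_without_copy(name, options={}):
    options.setdefault("global_average_pooling", name.endswith("_shuffle"))
    return options

print(load_without_copy("control_shuffle"))  # {'global_average_pooling': True}
print(load_without_copy("control_canny"))    # still True: leaked from the first call

def load_with_copy(name, options={}):
    options = options.copy()                 # the behaviour the removed line provides
    options.setdefault("global_average_pooling", name.endswith("_shuffle"))
    return options

print(load_with_copy("control_shuffle"))     # {'global_average_pooling': True}
print(load_with_copy("control_canny"))       # {'global_average_pooling': False}
```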
@@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
     def forward(self, pixel_values):
         x = self.patch_embeddings(pixel_values)
         # TODO: mask_token?
-        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
         return x

@@ -8,12 +8,25 @@ from einops import repeat
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 import torch.nn.functional as F

-from comfy.ldm.flux.math import apply_rope, rope
-from comfy.ldm.flux.layers import LastLayer
+from comfy.ldm.flux.math import apply_rope

 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
-import comfy.ldm.common_dit
+
+# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+    return out.float()


 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
@@ -71,6 +84,23 @@ class TimestepEmbed(nn.Module):
         return t_emb


+class OutEmbed(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
+        )
+
+    def forward(self, x, adaln_input):
+        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
+        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = self.linear(x)
+        return x
+
+
 def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
     return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])

@@ -633,7 +663,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
             ]
         )

-        self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

         caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
         caption_projection = []
@@ -702,8 +732,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
         control = None,
         transformer_options = {},
     ) -> torch.Tensor:
-        bs, c, h, w = x.shape
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+        hidden_states = x
         timesteps = t
         pooled_embeds = y
         T5_encoder_hidden_states = context
@@ -796,4 +825,4 @@ class HiDreamImageTransformer2DModel(nn.Module):
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
         output = self.final_layer(hidden_states, adaln_input)
         output = self.unpatchify(output, img_sizes)
-        return -output[:, :, :h, :w]
+        return -output
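The `rope()` helper added on one side of the first HiDream hunk is the flux-style rotary position embedding table. A minimal numeric sketch of its output layout; since each trailing 2×2 block is a plane rotation, every block should have determinant 1:

```python
# Sanity-check sketch of the rope() table quoted in the hunk above (slightly condensed).
import torch

def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    stacked = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return stacked.view(pos.shape[0], -1, dim // 2, 2, 2).float()

pe = rope(torch.arange(8, dtype=torch.float64).unsqueeze(0), dim=64, theta=10000)
print(pe.shape)                                              # torch.Size([1, 8, 32, 2, 2])
print(torch.allclose(torch.linalg.det(pe), torch.ones(1)))   # True: cos^2 + sin^2 = 1
```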
@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):

 class WanT2VCrossAttention(WanSelfAttention):

-    def forward(self, x, context, **kwargs):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

-    def forward(self, x, context, context_img_len):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
         """
-        context_img = context[:, :context_img_len]
-        context = context[:, context_img_len:]
+        context_img = context[:, :257]
+        context = context[:, 257:]

         # compute query, key, value
         q = self.norm_q(self.q(x))
@@ -193,7 +193,6 @@ class WanAttentionBlock(nn.Module):
         e,
         freqs,
         context,
-        context_img_len=257,
     ):
         r"""
         Args:
@@ -214,40 +213,12 @@
         x = x + y * e[2]

         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
+        x = x + self.cross_attn(self.norm3(x), context)
         y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
         x = x + y * e[5]
         return x


-class VaceWanAttentionBlock(WanAttentionBlock):
-    def __init__(
-        self,
-        cross_attn_type,
-        dim,
-        ffn_dim,
-        num_heads,
-        window_size=(-1, -1),
-        qk_norm=True,
-        cross_attn_norm=False,
-        eps=1e-6,
-        block_id=0,
-        operation_settings={}
-    ):
-        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
-        self.block_id = block_id
-        if block_id == 0:
-            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-            self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-    def forward(self, c, x, **kwargs):
-        if self.block_id == 0:
-            c = self.before_proj(c) + x
-        c = super().forward(c, **kwargs)
-        c_skip = self.after_proj(c)
-        return c_skip, c
-
-
 class Head(nn.Module):

     def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -279,7 +250,7 @@ class Head(nn.Module):

 class MLPProj(torch.nn.Module):

-    def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
+    def __init__(self, in_dim, out_dim, operation_settings={}):
         super().__init__()

         self.proj = torch.nn.Sequential(
@@ -287,15 +258,7 @@ class MLPProj(torch.nn.Module):
             torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))

-        if flf_pos_embed_token_number is not None:
-            self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
-        else:
-            self.emb_pos = None
-
     def forward(self, image_embeds):
-        if self.emb_pos is not None:
-            image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
-
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens

@@ -321,7 +284,6 @@ class WanModel(torch.nn.Module):
                  qk_norm=True,
                  cross_attn_norm=True,
                  eps=1e-6,
-                 flf_pos_embed_token_number=None,
                  image_model=None,
                  device=None,
                  dtype=None,
@@ -411,7 +373,7 @@
         self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])

         if model_type == 'i2v':
-            self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
+            self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
         else:
             self.img_emb = None

@@ -423,7 +385,6 @@
         clip_fea=None,
         freqs=None,
         transformer_options={},
-        **kwargs,
     ):
         r"""
         Forward pass through the diffusion model
@@ -459,12 +420,9 @@
         # context
         context = self.text_embedding(context)

-        context_img_len = None
-        if clip_fea is not None:
-            if self.img_emb is not None:
-                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
-                context = torch.concat([context_clip, context], dim=1)
-            context_img_len = clip_fea.shape[-2]
+        if clip_fea is not None and self.img_emb is not None:
+            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+            context = torch.concat([context_clip, context], dim=1)

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
@@ -472,12 +430,12 @@
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                x = block(x, e=e0, freqs=freqs, context=context)

         # head
         x = self.head(x, e)
@@ -486,7 +444,7 @@
         x = self.unpatchify(x, grid_sizes)
         return x

-    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -500,7 +458,7 @@
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""
@@ -525,115 +483,3 @@
         u = torch.einsum('bfhwpqrc->bcfphqwr', u)
         u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
         return u
-
-
-class VaceWanModel(WanModel):
-    r"""
-    Wan diffusion backbone supporting both text-to-video and image-to-video.
-    """
-
-    def __init__(self,
-                 model_type='vace',
-                 patch_size=(1, 2, 2),
-                 text_len=512,
-                 in_dim=16,
-                 dim=2048,
-                 ffn_dim=8192,
-                 freq_dim=256,
-                 text_dim=4096,
-                 out_dim=16,
-                 num_heads=16,
-                 num_layers=32,
-                 window_size=(-1, -1),
-                 qk_norm=True,
-                 cross_attn_norm=True,
-                 eps=1e-6,
-                 flf_pos_embed_token_number=None,
-                 image_model=None,
-                 vace_layers=None,
-                 vace_in_dim=None,
-                 device=None,
-                 dtype=None,
-                 operations=None,
-                 ):
-
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
-        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
-
-        # Vace
-        if vace_layers is not None:
-            self.vace_layers = vace_layers
-            self.vace_in_dim = vace_in_dim
-            # vace blocks
-            self.vace_blocks = nn.ModuleList([
-                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
-                for i in range(self.vace_layers)
-            ])
-
-            self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
-            # vace patch embeddings
-            self.vace_patch_embedding = operations.Conv3d(
-                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
-            )
-
-    def forward_orig(
-        self,
-        x,
-        t,
-        context,
-        vace_context,
-        vace_strength=1.0,
-        clip_fea=None,
-        freqs=None,
-        transformer_options={},
-        **kwargs,
-    ):
-        # embeddings
-        x = self.patch_embedding(x.float()).to(x.dtype)
-        grid_sizes = x.shape[2:]
-        x = x.flatten(2).transpose(1, 2)
-
-        # time embeddings
-        e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
-        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
-
-        # context
-        context = self.text_embedding(context)
-
-        context_img_len = None
-        if clip_fea is not None:
-            if self.img_emb is not None:
-                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
-                context = torch.concat([context_clip, context], dim=1)
-            context_img_len = clip_fea.shape[-2]
-
-        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
-        c = c.flatten(2).transpose(1, 2)
-
-        # arguments
-        x_orig = x
-
-        patches_replace = transformer_options.get("patches_replace", {})
-        blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
-                x = out["img"]
-            else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
-
-            ii = self.vace_layers_mapping.get(i, None)
-            if ii is not None:
-                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
-                x += c_skip * vace_strength
-        # head
-        x = self.head(x, e)
-
-        # unpatchify
-        x = self.unpatchify(x, grid_sizes)
-        return x
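Several Wan hunks above swap a general `context_img_len` (read from `clip_fea.shape[-2]`) for the hard-coded 257. A minimal arithmetic sketch of where 257 comes from, assuming the usual 224×224 ViT-L/14 CLIP vision tower behind `clip_fea`:

```python
# 257 = 1 class token + (224 / 14)^2 patch tokens for a ViT-L/14 CLIP vision encoder,
# so clip_fea is typically shaped [batch, 257, 1280] and clip_fea.shape[-2] == 257.
image_size, patch_size = 224, 14
image_tokens = 1 + (image_size // patch_size) ** 2
print(image_tokens)  # 257
```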
comfy/lora.py (321 lines changed)

@@ -20,7 +20,6 @@ from __future__ import annotations
 import comfy.utils
 import comfy.model_management
 import comfy.model_base
-import comfy.weight_adapter as weight_adapter
 import logging
 import torch

@@ -50,12 +49,139 @@ def load_lora(lora, to_load, log_missing=True):
             dora_scale = lora[dora_scale_name]
             loaded_keys.add(dora_scale_name)

-        for adapter_cls in weight_adapter.adapters:
-            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
-            if adapter is not None:
-                patch_dict[to_load[x]] = adapter
-                loaded_keys.update(adapter.loaded_keys)
-                continue
+        reshape_name = "{}.reshape_weight".format(x)
+        reshape = None
+        if reshape_name in lora.keys():
+            try:
+                reshape = lora[reshape_name].tolist()
+                loaded_keys.add(reshape_name)
+            except:
+                pass
+
+        regular_lora = "{}.lora_up.weight".format(x)
+        diffusers_lora = "{}_lora.up.weight".format(x)
+        diffusers2_lora = "{}.lora_B.weight".format(x)
+        diffusers3_lora = "{}.lora.up.weight".format(x)
+        mochi_lora = "{}.lora_B".format(x)
+        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+        A_name = None
+
+        if regular_lora in lora.keys():
+            A_name = regular_lora
+            B_name = "{}.lora_down.weight".format(x)
+            mid_name = "{}.lora_mid.weight".format(x)
+        elif diffusers_lora in lora.keys():
+            A_name = diffusers_lora
+            B_name = "{}_lora.down.weight".format(x)
+            mid_name = None
+        elif diffusers2_lora in lora.keys():
+            A_name = diffusers2_lora
+            B_name = "{}.lora_A.weight".format(x)
+            mid_name = None
+        elif diffusers3_lora in lora.keys():
+            A_name = diffusers3_lora
+            B_name = "{}.lora.down.weight".format(x)
+            mid_name = None
+        elif mochi_lora in lora.keys():
+            A_name = mochi_lora
+            B_name = "{}.lora_A".format(x)
+            mid_name = None
+        elif transformers_lora in lora.keys():
+            A_name = transformers_lora
+            B_name ="{}.lora_linear_layer.down.weight".format(x)
+            mid_name = None
+
+        if A_name is not None:
+            mid = None
+            if mid_name is not None and mid_name in lora.keys():
+                mid = lora[mid_name]
+                loaded_keys.add(mid_name)
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+
+
+        ######## loha
+        hada_w1_a_name = "{}.hada_w1_a".format(x)
+        hada_w1_b_name = "{}.hada_w1_b".format(x)
+        hada_w2_a_name = "{}.hada_w2_a".format(x)
+        hada_w2_b_name = "{}.hada_w2_b".format(x)
+        hada_t1_name = "{}.hada_t1".format(x)
+        hada_t2_name = "{}.hada_t2".format(x)
+        if hada_w1_a_name in lora.keys():
+            hada_t1 = None
+            hada_t2 = None
+            if hada_t1_name in lora.keys():
+                hada_t1 = lora[hada_t1_name]
+                hada_t2 = lora[hada_t2_name]
+                loaded_keys.add(hada_t1_name)
+                loaded_keys.add(hada_t2_name)
+
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
+            loaded_keys.add(hada_w1_a_name)
+            loaded_keys.add(hada_w1_b_name)
+            loaded_keys.add(hada_w2_a_name)
+            loaded_keys.add(hada_w2_b_name)
+
+
+        ######## lokr
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
+
+        #glora
+        a1_name = "{}.a1.weight".format(x)
+        a2_name = "{}.a2.weight".format(x)
+        b1_name = "{}.b1.weight".format(x)
+        b2_name = "{}.b2.weight".format(x)
+        if a1_name in lora:
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
+            loaded_keys.add(a1_name)
+            loaded_keys.add(a2_name)
+            loaded_keys.add(b1_name)
+            loaded_keys.add(b2_name)

         w_norm_name = "{}.w_norm".format(x)
         b_norm_name = "{}.b_norm".format(x)
@@ -282,6 +408,26 @@ def model_lora_keys_unet(model, key_map={}):
     return key_map


+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
+    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
+    lora_diff *= alpha
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
+    weight_norm = (
+        weight_calc.transpose(0, 1)
+        .reshape(weight_calc.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+    if strength != 1.0:
+        weight_calc -= weight
+        weight += strength * (weight_calc)
+    else:
+        weight[:] = weight_calc
+    return weight
+
+
 def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     """
     Pad a tensor to a new shape with zeros.
@@ -336,16 +482,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
        if isinstance(v, list):
            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

-       if isinstance(v, weight_adapter.WeightAdapterBase):
-           output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
-           if output is None:
-               logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
-           else:
-               weight = output
-               if old_weight is not None:
-                   weight = old_weight
-           continue
-
        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
@@ -372,6 +508,157 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
                diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                              comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
                weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
+       elif patch_type == "lora": #lora/locon
+           mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
+           mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
+           dora_scale = v[4]
+           reshape = v[5]
+
+           if reshape is not None:
+               weight = pad_tensor_to_shape(weight, reshape)
+
+           if v[2] is not None:
+               alpha = v[2] / mat2.shape[0]
+           else:
+               alpha = 1.0
+
+           if v[3] is not None:
+               #locon mid weights, hopefully the math is fine because I didn't properly test it
+               mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
+               final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+               mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+           try:
+               lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+               if dora_scale is not None:
+                   weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+               else:
+                   weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+           except Exception as e:
+               logging.error("ERROR {} {} {}".format(patch_type, key, e))
+       elif patch_type == "lokr":
+           w1 = v[0]
+           w2 = v[1]
+           w1_a = v[3]
+           w1_b = v[4]
+           w2_a = v[5]
+           w2_b = v[6]
+           t2 = v[7]
+           dora_scale = v[8]
+           dim = None
+
+           if w1 is None:
+               dim = w1_b.shape[0]
+               w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+                             comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
+           else:
+               w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
+
+           if w2 is None:
+               dim = w2_b.shape[0]
+               if t2 is None:
+                   w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+                                 comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
+               else:
+                   w2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                     comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                     comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+                                     comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
+           else:
+               w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
+
+           if len(w2.shape) == 4:
+               w1 = w1.unsqueeze(2).unsqueeze(2)
+           if v[2] is not None and dim is not None:
+               alpha = v[2] / dim
+           else:
+               alpha = 1.0
+
+           try:
+               lora_diff = torch.kron(w1, w2).reshape(weight.shape)
+               if dora_scale is not None:
+                   weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+               else:
+                   weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+           except Exception as e:
+               logging.error("ERROR {} {} {}".format(patch_type, key, e))
+       elif patch_type == "loha":
+           w1a = v[0]
+           w1b = v[1]
+           if v[2] is not None:
+               alpha = v[2] / w1b.shape[0]
+           else:
+               alpha = 1.0
+
+           w2a = v[3]
+           w2b = v[4]
+           dora_scale = v[7]
+           if v[5] is not None: #cp decomposition
+               t1 = v[5]
+               t2 = v[6]
+               m1 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                 comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+                                 comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+                                 comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
+
+               m2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                 comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                 comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+                                 comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
+           else:
+               m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+                             comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+               m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+                             comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
+
+           try:
+               lora_diff = (m1 * m2).reshape(weight.shape)
+               if dora_scale is not None:
+                   weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+               else:
+                   weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+           except Exception as e:
+               logging.error("ERROR {} {} {}".format(patch_type, key, e))
+       elif patch_type == "glora":
+           dora_scale = v[5]
+
+           old_glora = False
+           if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+               rank = v[0].shape[0]
+               old_glora = True
+
+           if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+               if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+                   pass
+               else:
+                   old_glora = False
+                   rank = v[1].shape[0]
+
+           a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+           a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+           b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+           b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
+
+           if v[4] is not None:
+               alpha = v[4] / rank
+           else:
+               alpha = 1.0
+
+           try:
+               if old_glora:
+                   lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+               else:
+                   if weight.dim() > 2:
+                       lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                   else:
+                       lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                   lora_diff += torch.mm(b1, b2).reshape(weight.shape)
+
+               if dora_scale is not None:
+                   weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+               else:
+                   weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+           except Exception as e:
+               logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

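The `weight_decompose` helper added in the lora.py hunks applies a DoRA-style update: after the LoRA diff is added, each column of the weight is renormalized to the stored `dora_scale` magnitude. A minimal shape-only sketch of that norm computation on a small linear weight (not the ComfyUI call path):

```python
# Sketch of the DoRA renormalization inside weight_decompose(): the per-column norm of
# (weight + lora_diff) is replaced by the stored dora_scale magnitude.
import torch

weight = torch.randn(8, 4)            # [out_features, in_features]
lora_diff = 0.1 * torch.randn(8, 4)
dora_scale = torch.ones(1, 4)          # stored per-column magnitude

weight_calc = weight + lora_diff
weight_norm = (
    weight_calc.transpose(0, 1)
    .reshape(weight_calc.shape[1], -1)
    .norm(dim=1, keepdim=True)
    .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
    .transpose(0, 1)
)                                      # shape [1, 4]: one norm per column
weight_calc *= dora_scale / weight_norm
print(weight_calc.norm(dim=0))         # ~[1., 1., 1., 1.]: columns now carry dora_scale
```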
@@ -1043,37 +1043,6 @@ class WAN21(BaseModel):
             out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out


-class WAN21_Vace(WAN21):
-    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
-        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
-        self.image_to_video = image_to_video
-
-    def extra_conds(self, **kwargs):
-        out = super().extra_conds(**kwargs)
-        noise = kwargs.get("noise", None)
-        noise_shape = list(noise.shape)
-        vace_frames = kwargs.get("vace_frames", None)
-        if vace_frames is None:
-            noise_shape[1] = 32
-            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
-
-        for i in range(0, vace_frames.shape[1], 16):
-            vace_frames = vace_frames.clone()
-            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
-
-        mask = kwargs.get("vace_mask", None)
-        if mask is None:
-            noise_shape[1] = 64
-            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
-
-        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
-
-        vace_strength = kwargs.get("vace_strength", 1.0)
-        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
-        return out
-
-
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -317,18 +317,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["cross_attn_norm"] = True
|
dit_config["cross_attn_norm"] = True
|
||||||
dit_config["eps"] = 1e-6
|
dit_config["eps"] = 1e-6
|
||||||
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "vace"
|
dit_config["model_type"] = "i2v"
|
||||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
|
||||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
dit_config["model_type"] = "t2v"
|
||||||
dit_config["model_type"] = "i2v"
|
|
||||||
else:
|
|
||||||
dit_config["model_type"] = "t2v"
|
|
||||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
|
||||||
if flf_weight is not None:
|
|
||||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||||
|
|||||||
@@ -725,8 +725,6 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
if args.fp8_e5m2_unet:
|
if args.fp8_e5m2_unet:
|
||||||
return torch.float8_e5m2
|
return torch.float8_e5m2
|
||||||
if args.fp8_e8m0fnu_unet:
|
|
||||||
return torch.float8_e8m0fnu
|
|
||||||
|
|
||||||
fp8_dtype = None
|
fp8_dtype = None
|
||||||
if weight_dtype in FLOAT8_TYPES:
|
if weight_dtype in FLOAT8_TYPES:
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ if RMSNorm is None:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return rms_norm(x, self.weight, self.eps)
|
return rms_norm(x, self.weight, self.eps)
|
||||||
|
|||||||
34
comfy/sd.py
34
comfy/sd.py
@@ -703,7 +703,6 @@ class CLIPType(Enum):
|
|||||||
COSMOS = 11
|
COSMOS = 11
|
||||||
LUMINA2 = 12
|
LUMINA2 = 12
|
||||||
WAN = 13
|
WAN = 13
|
||||||
HIDREAM = 14
|
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@@ -792,9 +791,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.SD3:
|
elif clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif clip_type == CLIPType.HIDREAM:
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
@@ -815,10 +811,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif clip_type == CLIPType.HIDREAM:
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
|
||||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
|
||||||
else: #CLIPType.MOCHI
|
else: #CLIPType.MOCHI
|
||||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||||
@@ -835,18 +827,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif te_model == TEModel.LLAMA3_8:
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
|
||||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
|
||||||
else:
|
else:
|
||||||
# clip_l
|
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif clip_type == CLIPType.HIDREAM:
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
@@ -864,24 +848,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
|
||||||
elif clip_type == CLIPType.HIDREAM:
|
|
||||||
# Detect
|
|
||||||
hidream_dualclip_classes = []
|
|
||||||
for hidream_te in clip_data:
|
|
||||||
te_model = detect_te_model(hidream_te)
|
|
||||||
hidream_dualclip_classes.append(te_model)
|
|
||||||
|
|
||||||
clip_l = TEModel.CLIP_L in hidream_dualclip_classes
|
|
||||||
clip_g = TEModel.CLIP_G in hidream_dualclip_classes
|
|
||||||
t5 = TEModel.T5_XXL in hidream_dualclip_classes
|
|
||||||
llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
|
|
||||||
|
|
||||||
# Initialize t5xxl_detect and llama_detect kwargs if needed
|
|
||||||
t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
|
|
||||||
llama_kwargs = llama_detect(clip_data) if llama else {}
|
|
||||||
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
|||||||
@@ -987,16 +987,6 @@ class WAN21_FunControl2V(WAN21_T2V):
|
|||||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class WAN21_Vace(WAN21_T2V):
|
|
||||||
unet_config = {
|
|
||||||
"image_model": "wan2.1",
|
|
||||||
"model_type": "vace",
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
|
||||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
|
||||||
return out
|
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@@ -1065,6 +1055,6 @@ class HiDream(supported_models_base.BASE):
|
|||||||
return None # TODO
|
return None # TODO
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@@ -11,15 +11,14 @@ class HiDreamTokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
|
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
|
||||||
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
|
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||||
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
|
||||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -109,18 +108,14 @@ class HiDreamTEModel(torch.nn.Module):
|
|||||||
if self.t5xxl is not None:
|
if self.t5xxl is not None:
|
||||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||||
t5_out, t5_pooled = t5_output[:2]
|
t5_out, t5_pooled = t5_output[:2]
|
||||||
else:
|
|
||||||
t5_out = None
|
|
||||||
|
|
||||||
if self.llama is not None:
|
if self.llama is not None:
|
||||||
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
|
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||||
ll_out, ll_pooled = ll_output[:2]
|
ll_out, ll_pooled = ll_output[:2]
|
||||||
ll_out = ll_out[:, 1:]
|
ll_out = ll_out[:, 1:]
|
||||||
else:
|
|
||||||
ll_out = None
|
|
||||||
|
|
||||||
if t5_out is None:
|
if t5_out is None:
|
||||||
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
|
t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
if ll_out is None:
|
if ll_out is None:
|
||||||
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|||||||
@@ -32,9 +32,9 @@ def t5_xxl_detect(state_dict, prefix=""):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
|
|
||||||
class SPieceTokenizer:
|
class SPieceTokenizer:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -16,8 +15,6 @@ class SPieceTokenizer:
|
|||||||
if isinstance(tokenizer_path, bytes):
|
if isinstance(tokenizer_path, bytes):
|
||||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||||
else:
|
else:
|
||||||
if not os.path.isfile(tokenizer_path):
|
|
||||||
raise ValueError("invalid tokenizer")
|
|
||||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from .base import WeightAdapterBase
|
|
||||||
from .lora import LoRAAdapter
|
|
||||||
from .loha import LoHaAdapter
|
|
||||||
from .lokr import LoKrAdapter
|
|
||||||
from .glora import GLoRAAdapter
|
|
||||||
from .oft import OFTAdapter
|
|
||||||
from .boft import BOFTAdapter
|
|
||||||
|
|
||||||
|
|
||||||
adapters: list[type[WeightAdapterBase]] = [
|
|
||||||
LoRAAdapter,
|
|
||||||
LoHaAdapter,
|
|
||||||
LoKrAdapter,
|
|
||||||
GLoRAAdapter,
|
|
||||||
OFTAdapter,
|
|
||||||
BOFTAdapter,
|
|
||||||
]
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
|
|
||||||
|
|
||||||
class WeightAdapterBase:
|
|
||||||
name: str
|
|
||||||
loaded_keys: set[str]
|
|
||||||
weights: list[torch.Tensor]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def to_train(self) -> "WeightAdapterTrainBase":
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class WeightAdapterTrainBase(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# [TODO] Collaborate with LoRA training PR #7032
|
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
|
||||||
lora_diff *= alpha
|
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
|
||||||
|
|
||||||
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
|
|
||||||
if wd_on_output_axis:
|
|
||||||
weight_norm = (
|
|
||||||
weight.reshape(weight.shape[0], -1)
|
|
||||||
.norm(dim=1, keepdim=True)
|
|
||||||
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
weight_norm = (
|
|
||||||
weight_calc.transpose(0, 1)
|
|
||||||
.reshape(weight_calc.shape[1], -1)
|
|
||||||
.norm(dim=1, keepdim=True)
|
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
|
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
|
||||||
weight[:] = weight_calc
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Pad a tensor to a new shape with zeros.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): The original tensor to be padded.
|
|
||||||
new_shape (List[int]): The desired shape of the padded tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
If the new shape is smaller than the original tensor in any dimension,
|
|
||||||
the original tensor will be truncated in that dimension.
|
|
||||||
"""
|
|
||||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
|
||||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
|
||||||
|
|
||||||
if len(new_shape) != len(tensor.shape):
|
|
||||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
|
||||||
|
|
||||||
# Create a new tensor filled with zeros
|
|
||||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
|
||||||
|
|
||||||
# Create slicing tuples for both tensors
|
|
||||||
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
|
||||||
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
|
||||||
|
|
||||||
# Copy the original tensor into the new tensor
|
|
||||||
padded_tensor[new_slices] = tensor[orig_slices]
|
|
||||||
|
|
||||||
return padded_tensor
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
|
||||||
|
|
||||||
|
|
||||||
class BOFTAdapter(WeightAdapterBase):
|
|
||||||
name = "boft"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["BOFTAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
blocks_name = "{}.boft_blocks".format(x)
|
|
||||||
rescale_name = "{}.rescale".format(x)
|
|
||||||
|
|
||||||
blocks = None
|
|
||||||
if blocks_name in lora.keys():
|
|
||||||
blocks = lora[blocks_name]
|
|
||||||
if blocks.ndim == 4:
|
|
||||||
loaded_keys.add(blocks_name)
|
|
||||||
|
|
||||||
rescale = None
|
|
||||||
if rescale_name in lora.keys():
|
|
||||||
rescale = lora[rescale_name]
|
|
||||||
loaded_keys.add(rescale_name)
|
|
||||||
|
|
||||||
if blocks is not None:
|
|
||||||
weights = (blocks, rescale, alpha, dora_scale)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
blocks = v[0]
|
|
||||||
rescale = v[1]
|
|
||||||
alpha = v[2]
|
|
||||||
dora_scale = v[3]
|
|
||||||
|
|
||||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
|
||||||
if rescale is not None:
|
|
||||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
boft_m, block_num, boft_b, *_ = blocks.shape
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get r
|
|
||||||
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
|
|
||||||
# for Q = -Q^T
|
|
||||||
q = blocks - blocks.transpose(1, 2)
|
|
||||||
normed_q = q
|
|
||||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
|
||||||
q_norm = torch.norm(q) + 1e-8
|
|
||||||
if q_norm > alpha:
|
|
||||||
normed_q = q * alpha / q_norm
|
|
||||||
# use float() to prevent unsupported type in .inverse()
|
|
||||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
||||||
r = r.to(original_weight)
|
|
||||||
|
|
||||||
inp = org = original_weight
|
|
||||||
|
|
||||||
r_b = boft_b//2
|
|
||||||
for i in range(boft_m):
|
|
||||||
bi = r[i]
|
|
||||||
g = 2
|
|
||||||
k = 2**i * r_b
|
|
||||||
if strength != 1:
|
|
||||||
bi = bi * strength + (1-strength) * I
|
|
||||||
inp = (
|
|
||||||
inp.unflatten(-1, (-1, g, k))
|
|
||||||
.transpose(-2, -1)
|
|
||||||
.flatten(-3)
|
|
||||||
.unflatten(-1, (-1, boft_b))
|
|
||||||
)
|
|
||||||
inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
|
|
||||||
inp = (
|
|
||||||
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
|
||||||
)
|
|
||||||
|
|
||||||
if rescale is not None:
|
|
||||||
inp = inp * rescale
|
|
||||||
|
|
||||||
lora_diff = inp - org
|
|
||||||
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
|
||||||
|
|
||||||
|
|
||||||
class GLoRAAdapter(WeightAdapterBase):
|
|
||||||
name = "glora"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["GLoRAAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
a1_name = "{}.a1.weight".format(x)
|
|
||||||
a2_name = "{}.a2.weight".format(x)
|
|
||||||
b1_name = "{}.b1.weight".format(x)
|
|
||||||
b2_name = "{}.b2.weight".format(x)
|
|
||||||
if a1_name in lora:
|
|
||||||
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
|
||||||
loaded_keys.add(a1_name)
|
|
||||||
loaded_keys.add(a2_name)
|
|
||||||
loaded_keys.add(b1_name)
|
|
||||||
loaded_keys.add(b2_name)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
old_glora = False
|
|
||||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
|
||||||
rank = v[0].shape[0]
|
|
||||||
old_glora = True
|
|
||||||
|
|
||||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
|
||||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
old_glora = False
|
|
||||||
rank = v[1].shape[0]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / rank
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if old_glora:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
|
||||||
else:
|
|
||||||
if weight.dim() > 2:
|
|
||||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
else:
|
|
||||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
|
||||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
|
||||||
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
|
||||||
|
|
||||||
|
|
||||||
class LoHaAdapter(WeightAdapterBase):
|
|
||||||
name = "loha"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["LoHaAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
|
|
||||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
|
||||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
|
||||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
|
||||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
|
||||||
hada_t1_name = "{}.hada_t1".format(x)
|
|
||||||
hada_t2_name = "{}.hada_t2".format(x)
|
|
||||||
if hada_w1_a_name in lora.keys():
|
|
||||||
hada_t1 = None
|
|
||||||
hada_t2 = None
|
|
||||||
if hada_t1_name in lora.keys():
|
|
||||||
hada_t1 = lora[hada_t1_name]
|
|
||||||
hada_t2 = lora[hada_t2_name]
|
|
||||||
loaded_keys.add(hada_t1_name)
|
|
||||||
loaded_keys.add(hada_t2_name)
|
|
||||||
|
|
||||||
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
|
||||||
loaded_keys.add(hada_w1_a_name)
|
|
||||||
loaded_keys.add(hada_w1_b_name)
|
|
||||||
loaded_keys.add(hada_w2_a_name)
|
|
||||||
loaded_keys.add(hada_w2_b_name)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
|
||||||
|
|
||||||
|
|
||||||
class LoKrAdapter(WeightAdapterBase):
|
|
||||||
name = "lokr"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["LoKrAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
|
||||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
|
||||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
|
||||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
|
||||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
|
||||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
|
||||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
|
||||||
|
|
||||||
lokr_w1 = None
|
|
||||||
if lokr_w1_name in lora.keys():
|
|
||||||
lokr_w1 = lora[lokr_w1_name]
|
|
||||||
loaded_keys.add(lokr_w1_name)
|
|
||||||
|
|
||||||
lokr_w2 = None
|
|
||||||
if lokr_w2_name in lora.keys():
|
|
||||||
lokr_w2 = lora[lokr_w2_name]
|
|
||||||
loaded_keys.add(lokr_w2_name)
|
|
||||||
|
|
||||||
lokr_w1_a = None
|
|
||||||
if lokr_w1_a_name in lora.keys():
|
|
||||||
lokr_w1_a = lora[lokr_w1_a_name]
|
|
||||||
loaded_keys.add(lokr_w1_a_name)
|
|
||||||
|
|
||||||
lokr_w1_b = None
|
|
||||||
if lokr_w1_b_name in lora.keys():
|
|
||||||
lokr_w1_b = lora[lokr_w1_b_name]
|
|
||||||
loaded_keys.add(lokr_w1_b_name)
|
|
||||||
|
|
||||||
lokr_w2_a = None
|
|
||||||
if lokr_w2_a_name in lora.keys():
|
|
||||||
lokr_w2_a = lora[lokr_w2_a_name]
|
|
||||||
loaded_keys.add(lokr_w2_a_name)
|
|
||||||
|
|
||||||
lokr_w2_b = None
|
|
||||||
if lokr_w2_b_name in lora.keys():
|
|
||||||
lokr_w2_b = lora[lokr_w2_b_name]
|
|
||||||
loaded_keys.add(lokr_w2_b_name)
|
|
||||||
|
|
||||||
lokr_t2 = None
|
|
||||||
if lokr_t2_name in lora.keys():
|
|
||||||
lokr_t2 = lora[lokr_t2_name]
|
|
||||||
loaded_keys.add(lokr_t2_name)
|
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
|
||||||
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAAdapter(WeightAdapterBase):
|
|
||||||
name = "lora"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["LoRAAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
|
|
||||||
reshape_name = "{}.reshape_weight".format(x)
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
|
||||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
|
||||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
|
||||||
mochi_lora = "{}.lora_B".format(x)
|
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
|
||||||
A_name = None
|
|
||||||
|
|
||||||
if regular_lora in lora.keys():
|
|
||||||
A_name = regular_lora
|
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
|
||||||
elif diffusers_lora in lora.keys():
|
|
||||||
A_name = diffusers_lora
|
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers2_lora in lora.keys():
|
|
||||||
A_name = diffusers2_lora
|
|
||||||
B_name = "{}.lora_A.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers3_lora in lora.keys():
|
|
||||||
A_name = diffusers3_lora
|
|
||||||
B_name = "{}.lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif mochi_lora in lora.keys():
|
|
||||||
A_name = mochi_lora
|
|
||||||
B_name = "{}.lora_A".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif transformers_lora in lora.keys():
|
|
||||||
A_name = transformers_lora
|
|
||||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
|
|
||||||
if A_name is not None:
|
|
||||||
mid = None
|
|
||||||
if mid_name is not None and mid_name in lora.keys():
|
|
||||||
mid = lora[mid_name]
|
|
||||||
loaded_keys.add(mid_name)
|
|
||||||
reshape = None
|
|
||||||
if reshape_name in lora.keys():
|
|
||||||
try:
|
|
||||||
reshape = lora[reshape_name].tolist()
|
|
||||||
loaded_keys.add(reshape_name)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
|
|
||||||
loaded_keys.add(A_name)
|
|
||||||
loaded_keys.add(B_name)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
mat1 = comfy.model_management.cast_to_device(
|
|
||||||
v[0], weight.device, intermediate_dtype
|
|
||||||
)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(
|
|
||||||
v[1], weight.device, intermediate_dtype
|
|
||||||
)
|
|
||||||
dora_scale = v[4]
|
|
||||||
reshape = v[5]
|
|
||||||
|
|
||||||
if reshape is not None:
|
|
||||||
weight = pad_tensor_to_shape(weight, reshape)
|
|
||||||
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(
|
|
||||||
v[3], weight.device, intermediate_dtype
|
|
||||||
)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = (
|
|
||||||
torch.mm(
|
|
||||||
mat2.transpose(0, 1).flatten(start_dim=1),
|
|
||||||
mat3.transpose(0, 1).flatten(start_dim=1),
|
|
||||||
)
|
|
||||||
.reshape(final_shape)
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(
|
|
||||||
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
|
||||||
).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(
|
|
||||||
dora_scale,
|
|
||||||
weight,
|
|
||||||
lora_diff,
|
|
||||||
alpha,
|
|
||||||
strength,
|
|
||||||
intermediate_dtype,
|
|
||||||
function,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
|
||||||
|
|
||||||
|
|
||||||
class OFTAdapter(WeightAdapterBase):
|
|
||||||
name = "oft"
|
|
||||||
|
|
||||||
def __init__(self, loaded_keys, weights):
|
|
||||||
self.loaded_keys = loaded_keys
|
|
||||||
self.weights = weights
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
x: str,
|
|
||||||
lora: dict[str, torch.Tensor],
|
|
||||||
alpha: float,
|
|
||||||
dora_scale: torch.Tensor,
|
|
||||||
loaded_keys: set[str] = None,
|
|
||||||
) -> Optional["OFTAdapter"]:
|
|
||||||
if loaded_keys is None:
|
|
||||||
loaded_keys = set()
|
|
||||||
blocks_name = "{}.oft_blocks".format(x)
|
|
||||||
rescale_name = "{}.rescale".format(x)
|
|
||||||
|
|
||||||
blocks = None
|
|
||||||
if blocks_name in lora.keys():
|
|
||||||
blocks = lora[blocks_name]
|
|
||||||
if blocks.ndim == 3:
|
|
||||||
loaded_keys.add(blocks_name)
|
|
||||||
|
|
||||||
rescale = None
|
|
||||||
if rescale_name in lora.keys():
|
|
||||||
rescale = lora[rescale_name]
|
|
||||||
loaded_keys.add(rescale_name)
|
|
||||||
|
|
||||||
if blocks is not None:
|
|
||||||
weights = (blocks, rescale, alpha, dora_scale)
|
|
||||||
return cls(loaded_keys, weights)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def calculate_weight(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
key,
|
|
||||||
strength,
|
|
||||||
strength_model,
|
|
||||||
offset,
|
|
||||||
function,
|
|
||||||
intermediate_dtype=torch.float32,
|
|
||||||
original_weight=None,
|
|
||||||
):
|
|
||||||
v = self.weights
|
|
||||||
blocks = v[0]
|
|
||||||
rescale = v[1]
|
|
||||||
alpha = v[2]
|
|
||||||
dora_scale = v[3]
|
|
||||||
|
|
||||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
|
||||||
if rescale is not None:
|
|
||||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
|
||||||
|
|
||||||
block_num, block_size, *_ = blocks.shape
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get r
|
|
||||||
I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
|
|
||||||
# for Q = -Q^T
|
|
||||||
q = blocks - blocks.transpose(1, 2)
|
|
||||||
normed_q = q
|
|
||||||
if alpha > 0: # alpha in oft/boft is for constraint
|
|
||||||
q_norm = torch.norm(q) + 1e-8
|
|
||||||
if q_norm > alpha:
|
|
||||||
normed_q = q * alpha / q_norm
|
|
||||||
# use float() to prevent unsupported type in .inverse()
|
|
||||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
|
||||||
r = r.to(original_weight)
|
|
||||||
lora_diff = torch.einsum(
|
|
||||||
"k n m, k n ... -> k m ...",
|
|
||||||
(r * strength) - strength * I,
|
|
||||||
original_weight,
|
|
||||||
)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
|
||||||
return weight
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: https://api.comfy.org/openapi
|
|
||||||
# timestamp: 2025-04-23T15:56:33+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from . import PixverseDto
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseData(BaseModel):
|
|
||||||
ErrCode: Optional[int] = None
|
|
||||||
ErrMsg: Optional[str] = None
|
|
||||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: https://api.comfy.org/openapi
|
|
||||||
# timestamp: 2025-04-23T15:56:33+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, constr
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPII2VResp(BaseModel):
|
|
||||||
video_id: Optional[int] = Field(None, description='Video_id')
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPIT2VReq(BaseModel):
|
|
||||||
aspect_ratio: str = Field(
|
|
||||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
|
||||||
)
|
|
||||||
duration: int = Field(
|
|
||||||
...,
|
|
||||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
|
||||||
examples=[5],
|
|
||||||
)
|
|
||||||
model: str = Field(
|
|
||||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
|
||||||
)
|
|
||||||
motion_mode: Optional[str] = Field(
|
|
||||||
'normal',
|
|
||||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
|
||||||
examples=['normal'],
|
|
||||||
)
|
|
||||||
negative_prompt: Optional[constr(max_length=2048)] = Field(
|
|
||||||
None, description='Negative prompt\n'
|
|
||||||
)
|
|
||||||
prompt: constr(max_length=2048) = Field(..., description='Prompt')
|
|
||||||
quality: str = Field(
|
|
||||||
...,
|
|
||||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
|
||||||
examples=['540p'],
|
|
||||||
)
|
|
||||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
|
||||||
style: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
|
||||||
examples=['anime'],
|
|
||||||
)
|
|
||||||
template_id: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='Template ID (template_id must be activated before use)',
|
|
||||||
examples=[302325299692608],
|
|
||||||
)
|
|
||||||
water_mark: Optional[bool] = Field(
|
|
||||||
False,
|
|
||||||
description='Watermark (true: add watermark, false: no watermark)',
|
|
||||||
examples=[False],
|
|
||||||
)
|
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: https://api.comfy.org/openapi
|
|
||||||
# timestamp: 2025-04-23T15:56:33+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from pydantic import AnyUrl, BaseModel, Field, confloat, conint
|
|
||||||
|
|
||||||
class Customer(BaseModel):
|
|
||||||
createdAt: Optional[datetime] = Field(
|
|
||||||
None, description='The date and time the user was created'
|
|
||||||
)
|
|
||||||
email: Optional[str] = Field(None, description='The email address for this user')
|
|
||||||
id: str = Field(..., description='The firebase UID of the user')
|
|
||||||
name: Optional[str] = Field(None, description='The name for this user')
|
|
||||||
updatedAt: Optional[datetime] = Field(
|
|
||||||
None, description='The date and time the user was last updated'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Error(BaseModel):
|
|
||||||
details: Optional[List[str]] = Field(
|
|
||||||
None,
|
|
||||||
description='Optional detailed information about the error or hints for resolving it.',
|
|
||||||
)
|
|
||||||
message: Optional[str] = Field(
|
|
||||||
None, description='A clear and concise description of the error.'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
|
||||||
error: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
class ImageRequest(BaseModel):
|
|
||||||
aspect_ratio: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
|
|
||||||
)
|
|
||||||
color_palette: Optional[Dict[str, Any]] = Field(
|
|
||||||
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
|
|
||||||
)
|
|
||||||
magic_prompt_option: Optional[str] = Field(
|
|
||||||
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
|
|
||||||
)
|
|
||||||
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
|
|
||||||
negative_prompt: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
|
|
||||||
)
|
|
||||||
num_images: Optional[conint(ge=1, le=8)] = Field(
|
|
||||||
1, description='Optional. Number of images to generate (1-8). Defaults to 1.'
|
|
||||||
)
|
|
||||||
prompt: str = Field(
|
|
||||||
..., description='Required. The prompt to use to generate the image.'
|
|
||||||
)
|
|
||||||
resolution: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
|
|
||||||
)
|
|
||||||
seed: Optional[conint(ge=0, le=2147483647)] = Field(
|
|
||||||
None, description='Optional. A number between 0 and 2147483647.'
|
|
||||||
)
|
|
||||||
style_type: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Datum(BaseModel):
|
|
||||||
is_image_safe: Optional[bool] = Field(
|
|
||||||
None, description='Indicates whether the image is considered safe.'
|
|
||||||
)
|
|
||||||
prompt: Optional[str] = Field(
|
|
||||||
None, description='The prompt used to generate this image.'
|
|
||||||
)
|
|
||||||
resolution: Optional[str] = Field(
|
|
||||||
None, description="The resolution of the generated image (e.g., '1024x1024')."
|
|
||||||
)
|
|
||||||
seed: Optional[int] = Field(
|
|
||||||
None, description='The seed value used for this generation.'
|
|
||||||
)
|
|
||||||
style_type: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
|
|
||||||
)
|
|
||||||
url: Optional[str] = Field(None, description='URL to the generated image.')
|
|
||||||
|
|
||||||
|
|
||||||
class Code(Enum):
|
|
||||||
int_1100 = 1100
|
|
||||||
int_1101 = 1101
|
|
||||||
int_1102 = 1102
|
|
||||||
int_1103 = 1103
|
|
||||||
|
|
||||||
|
|
||||||
class Code1(Enum):
|
|
||||||
int_1000 = 1000
|
|
||||||
int_1001 = 1001
|
|
||||||
int_1002 = 1002
|
|
||||||
int_1003 = 1003
|
|
||||||
int_1004 = 1004
|
|
||||||
|
|
||||||
|
|
||||||
class AspectRatio(str, Enum):
|
|
||||||
field_16_9 = '16:9'
|
|
||||||
field_9_16 = '9:16'
|
|
||||||
field_1_1 = '1:1'
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
|
||||||
horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
pan: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
roll: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
tilt: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
vertical: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
zoom: Optional[confloat(ge=-10.0, le=10.0)] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Type(str, Enum):
|
|
||||||
simple = 'simple'
|
|
||||||
down_back = 'down_back'
|
|
||||||
forward_up = 'forward_up'
|
|
||||||
right_turn_forward = 'right_turn_forward'
|
|
||||||
left_turn_forward = 'left_turn_forward'
|
|
||||||
|
|
||||||
|
|
||||||
class CameraControl(BaseModel):
|
|
||||||
config: Optional[Config] = None
|
|
||||||
type: Optional[Type] = Field(None, description='Predefined camera movements type')
|
|
||||||
|
|
||||||
|
|
||||||
class Duration(str, Enum):
|
|
||||||
field_5 = 5
|
|
||||||
field_10 = 10
|
|
||||||
|
|
||||||
|
|
||||||
class Mode(str, Enum):
|
|
||||||
std = 'std'
|
|
||||||
pro = 'pro'
|
|
||||||
|
|
||||||
|
|
||||||
class TaskInfo(BaseModel):
|
|
||||||
external_task_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Video(BaseModel):
|
|
||||||
duration: Optional[str] = Field(None, description='Total video duration')
|
|
||||||
id: Optional[str] = Field(None, description='Generated video ID')
|
|
||||||
url: Optional[AnyUrl] = Field(None, description='URL for generated video')
|
|
||||||
|
|
||||||
|
|
||||||
class TaskResult(BaseModel):
|
|
||||||
videos: Optional[List[Video]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
|
||||||
submitted = 'submitted'
|
|
||||||
processing = 'processing'
|
|
||||||
succeed = 'succeed'
|
|
||||||
failed = 'failed'
|
|
||||||
|
|
||||||
|
|
||||||
class Data(BaseModel):
|
|
||||||
created_at: Optional[int] = Field(None, description='Task creation time')
|
|
||||||
task_id: Optional[str] = Field(None, description='Task ID')
|
|
||||||
task_info: Optional[TaskInfo] = None
|
|
||||||
task_result: Optional[TaskResult] = None
|
|
||||||
task_status: Optional[TaskStatus] = None
|
|
||||||
updated_at: Optional[int] = Field(None, description='Task update time')
|
|
||||||
|
|
||||||
|
|
||||||
class AspectRatio1(str, Enum):
|
|
||||||
field_16_9 = '16:9'
|
|
||||||
field_9_16 = '9:16'
|
|
||||||
field_1_1 = '1:1'
|
|
||||||
field_4_3 = '4:3'
|
|
||||||
field_3_4 = '3:4'
|
|
||||||
field_3_2 = '3:2'
|
|
||||||
field_2_3 = '2:3'
|
|
||||||
field_21_9 = '21:9'
|
|
||||||
|
|
||||||
|
|
||||||
class ImageReference(str, Enum):
|
|
||||||
subject = 'subject'
|
|
||||||
face = 'face'
|
|
||||||
|
|
||||||
|
|
||||||
class Image(BaseModel):
|
|
||||||
index: Optional[int] = Field(None, description='Image Number (0-9)')
|
|
||||||
url: Optional[AnyUrl] = Field(None, description='URL for generated image')
|
|
||||||
|
|
||||||
|
|
||||||
class TaskResult1(BaseModel):
|
|
||||||
images: Optional[List[Image]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Data1(BaseModel):
|
|
||||||
created_at: Optional[int] = Field(None, description='Task creation time')
|
|
||||||
task_id: Optional[str] = Field(None, description='Task ID')
|
|
||||||
task_result: Optional[TaskResult1] = None
|
|
||||||
task_status: Optional[TaskStatus] = None
|
|
||||||
task_status_msg: Optional[str] = Field(None, description='Task status information')
|
|
||||||
updated_at: Optional[int] = Field(None, description='Task update time')
|
|
||||||
|
|
||||||
|
|
||||||
class AspectRatio2(str, Enum):
|
|
||||||
field_16_9 = '16:9'
|
|
||||||
field_9_16 = '9:16'
|
|
||||||
field_1_1 = '1:1'
|
|
||||||
|
|
||||||
|
|
||||||
class CameraControl1(BaseModel):
|
|
||||||
config: Optional[Config] = None
|
|
||||||
type: Optional[Type] = Field(None, description='Predefined camera movements type')
|
|
||||||
|
|
||||||
|
|
||||||
class ModelName2(str, Enum):
|
|
||||||
kling_v1 = 'kling-v1'
|
|
||||||
kling_v1_6 = 'kling-v1-6'
|
|
||||||
|
|
||||||
|
|
||||||
class TaskResult2(BaseModel):
|
|
||||||
videos: Optional[List[Video]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class Data2(BaseModel):
|
|
||||||
created_at: Optional[int] = Field(None, description='Task creation time')
|
|
||||||
task_id: Optional[str] = Field(None, description='Task ID')
|
|
||||||
task_info: Optional[TaskInfo] = None
|
|
||||||
task_result: Optional[TaskResult2] = None
|
|
||||||
task_status: Optional[TaskStatus] = None
|
|
||||||
updated_at: Optional[int] = Field(None, description='Task update time')
|
|
||||||
|
|
||||||
|
|
||||||
class Code2(Enum):
|
|
||||||
int_1200 = 1200
|
|
||||||
int_1201 = 1201
|
|
||||||
int_1202 = 1202
|
|
||||||
int_1203 = 1203
|
|
||||||
|
|
||||||
|
|
||||||
class ResourcePackType(str, Enum):
|
|
||||||
decreasing_total = 'decreasing_total'
|
|
||||||
constant_period = 'constant_period'
|
|
||||||
|
|
||||||
|
|
||||||
class Status(str, Enum):
|
|
||||||
toBeOnline = 'toBeOnline'
|
|
||||||
online = 'online'
|
|
||||||
expired = 'expired'
|
|
||||||
runOut = 'runOut'
|
|
||||||
|
|
||||||
|
|
||||||
class ResourcePackSubscribeInfo(BaseModel):
|
|
||||||
effective_time: Optional[int] = Field(
|
|
||||||
None, description='Effective time, Unix timestamp in ms'
|
|
||||||
)
|
|
||||||
invalid_time: Optional[int] = Field(
|
|
||||||
None, description='Expiration time, Unix timestamp in ms'
|
|
||||||
)
|
|
||||||
purchase_time: Optional[int] = Field(
|
|
||||||
None, description='Purchase time, Unix timestamp in ms'
|
|
||||||
)
|
|
||||||
remaining_quantity: Optional[float] = Field(
|
|
||||||
None, description='Remaining quantity (updated with a 12-hour delay)'
|
|
||||||
)
|
|
||||||
resource_pack_id: Optional[str] = Field(None, description='Resource package ID')
|
|
||||||
resource_pack_name: Optional[str] = Field(None, description='Resource package name')
|
|
||||||
resource_pack_type: Optional[ResourcePackType] = Field(
|
|
||||||
None,
|
|
||||||
description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)',
|
|
||||||
)
|
|
||||||
status: Optional[Status] = Field(None, description='Resource Package Status')
|
|
||||||
total_quantity: Optional[float] = Field(None, description='Total quantity')
|
|
||||||
|
|
||||||
class Background(str, Enum):
|
|
||||||
transparent = 'transparent'
|
|
||||||
opaque = 'opaque'
|
|
||||||
|
|
||||||
|
|
||||||
class Moderation(str, Enum):
|
|
||||||
low = 'low'
|
|
||||||
auto = 'auto'
|
|
||||||
|
|
||||||
|
|
||||||
class OutputFormat(str, Enum):
|
|
||||||
png = 'png'
|
|
||||||
webp = 'webp'
|
|
||||||
jpeg = 'jpeg'
|
|
||||||
|
|
||||||
|
|
||||||
class Quality(str, Enum):
|
|
||||||
low = 'low'
|
|
||||||
medium = 'medium'
|
|
||||||
high = 'high'
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIImageEditRequest(BaseModel):
|
|
||||||
background: Optional[str] = Field(
|
|
||||||
None, description='Background transparency', examples=['opaque']
|
|
||||||
)
|
|
||||||
model: str = Field(
|
|
||||||
..., description='The model to use for image editing', examples=['gpt-image-1']
|
|
||||||
)
|
|
||||||
moderation: Optional[Moderation] = Field(
|
|
||||||
None, description='Content moderation setting', examples=['auto']
|
|
||||||
)
|
|
||||||
n: Optional[int] = Field(
|
|
||||||
None, description='The number of images to generate', examples=[1]
|
|
||||||
)
|
|
||||||
output_compression: Optional[int] = Field(
|
|
||||||
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
|
|
||||||
)
|
|
||||||
output_format: Optional[OutputFormat] = Field(
|
|
||||||
None, description='Format of the output image', examples=['png']
|
|
||||||
)
|
|
||||||
prompt: str = Field(
|
|
||||||
...,
|
|
||||||
description='A text description of the desired edit',
|
|
||||||
examples=['Give the rocketship rainbow coloring'],
|
|
||||||
)
|
|
||||||
quality: Optional[str] = Field(
|
|
||||||
None, description='The quality of the edited image', examples=['low']
|
|
||||||
)
|
|
||||||
size: Optional[str] = Field(
|
|
||||||
None, description='Size of the output image', examples=['1024x1024']
|
|
||||||
)
|
|
||||||
user: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='A unique identifier for end-user monitoring',
|
|
||||||
examples=['user-1234'],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Quality1(str, Enum):
|
|
||||||
low = 'low'
|
|
||||||
medium = 'medium'
|
|
||||||
high = 'high'
|
|
||||||
standard = 'standard'
|
|
||||||
hd = 'hd'
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormat(str, Enum):
|
|
||||||
url = 'url'
|
|
||||||
b64_json = 'b64_json'
|
|
||||||
|
|
||||||
|
|
||||||
class Style(str, Enum):
|
|
||||||
vivid = 'vivid'
|
|
||||||
natural = 'natural'
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIImageGenerationRequest(BaseModel):
|
|
||||||
background: Optional[Background] = Field(
|
|
||||||
None, description='Background transparency', examples=['opaque']
|
|
||||||
)
|
|
||||||
model: Optional[str] = Field(
|
|
||||||
None, description='The model to use for image generation', examples=['dall-e-3']
|
|
||||||
)
|
|
||||||
moderation: Optional[Moderation] = Field(
|
|
||||||
None, description='Content moderation setting', examples=['auto']
|
|
||||||
)
|
|
||||||
n: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='The number of images to generate (1-10). Only 1 supported for dall-e-3.',
|
|
||||||
examples=[1],
|
|
||||||
)
|
|
||||||
output_compression: Optional[int] = Field(
|
|
||||||
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
|
|
||||||
)
|
|
||||||
output_format: Optional[OutputFormat] = Field(
|
|
||||||
None, description='Format of the output image', examples=['png']
|
|
||||||
)
|
|
||||||
prompt: str = Field(
|
|
||||||
...,
|
|
||||||
description='A text description of the desired image',
|
|
||||||
examples=['Draw a rocket in front of a blackhole in deep space'],
|
|
||||||
)
|
|
||||||
quality: Optional[Quality1] = Field(
|
|
||||||
None, description='The quality of the generated image', examples=['high']
|
|
||||||
)
|
|
||||||
response_format: Optional[ResponseFormat] = Field(
|
|
||||||
None, description='Response format of image data', examples=['b64_json']
|
|
||||||
)
|
|
||||||
size: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='Size of the image (e.g., 1024x1024, 1536x1024, auto)',
|
|
||||||
examples=['1024x1536'],
|
|
||||||
)
|
|
||||||
style: Optional[Style] = Field(
|
|
||||||
None, description='Style of the image (only for dall-e-3)', examples=['vivid']
|
|
||||||
)
|
|
||||||
user: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='A unique identifier for end-user monitoring',
|
|
||||||
examples=['user-1234'],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Datum1(BaseModel):
|
|
||||||
b64_json: Optional[str] = Field(None, description='Base64 encoded image data')
|
|
||||||
revised_prompt: Optional[str] = Field(None, description='Revised prompt')
|
|
||||||
url: Optional[str] = Field(None, description='URL of the image')
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIImageGenerationResponse(BaseModel):
|
|
||||||
data: Optional[List[Datum1]] = None
|
|
||||||
class User(BaseModel):
|
|
||||||
email: Optional[str] = Field(None, description='The email address for this user.')
|
|
||||||
id: Optional[str] = Field(None, description='The unique id for this user.')
|
|
||||||
isAdmin: Optional[bool] = Field(
|
|
||||||
None, description='Indicates if the user has admin privileges.'
|
|
||||||
)
|
|
||||||
isApproved: Optional[bool] = Field(
|
|
||||||
None, description='Indicates if the user is approved.'
|
|
||||||
)
|
|
||||||
name: Optional[str] = Field(None, description='The name for this user.')
|
|
||||||
@@ -1,337 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
"""
|
|
||||||
API Client Framework for api.comfy.org.
|
|
||||||
|
|
||||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
|
||||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
|
||||||
|
|
||||||
Key Components:
|
|
||||||
--------------
|
|
||||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
|
||||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
|
||||||
3. ApiOperation - Executes a single synchronous API operation
|
|
||||||
|
|
||||||
Usage Examples:
|
|
||||||
--------------
|
|
||||||
|
|
||||||
# Example 1: Synchronous API Operation
|
|
||||||
# ------------------------------------
|
|
||||||
# For a simple API call that returns the result immediately:
|
|
||||||
|
|
||||||
# 1. Create the API client
|
|
||||||
api_client = ApiClient(
|
|
||||||
base_url="https://api.example.com",
|
|
||||||
api_key="your_api_key_here",
|
|
||||||
timeout=30.0,
|
|
||||||
verify_ssl=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Define the endpoint
|
|
||||||
user_info_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/users/me",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest, # No request body needed
|
|
||||||
response_model=UserProfile, # Pydantic model for the response
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Create the request object
|
|
||||||
request = EmptyRequest()
|
|
||||||
|
|
||||||
# 4. Create and execute the operation
|
|
||||||
operation = ApiOperation(
|
|
||||||
endpoint=user_info_endpoint,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
user_profile = operation.execute(client=api_client) # Returns immediately with the result
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import (
|
|
||||||
Dict,
|
|
||||||
Type,
|
|
||||||
Optional,
|
|
||||||
Any,
|
|
||||||
TypeVar,
|
|
||||||
Generic,
|
|
||||||
)
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import requests
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
R = TypeVar("R", bound=BaseModel)
|
|
||||||
|
|
||||||
class EmptyRequest(BaseModel):
|
|
||||||
"""Base class for empty request bodies.
|
|
||||||
For GET requests, fields will be sent as query parameters."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HttpMethod(str, Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
PATCH = "PATCH"
|
|
||||||
|
|
||||||
|
|
||||||
class ApiClient:
|
|
||||||
"""
|
|
||||||
Client for making HTTP requests to an API with authentication and error handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
base_url: str,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
timeout: float = 30.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
):
|
|
||||||
self.base_url = base_url
|
|
||||||
self.api_key = api_key
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
|
||||||
"""Get headers for API requests, including authentication if available"""
|
|
||||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
|
||||||
|
|
||||||
if self.api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def request(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
path: str,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
json: Optional[Dict[str, Any]] = None,
|
|
||||||
files: Optional[Dict[str, Any]] = None,
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Make an HTTP request to the API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method: HTTP method (GET, POST, etc.)
|
|
||||||
path: API endpoint path (will be joined with base_url)
|
|
||||||
params: Query parameters
|
|
||||||
json: JSON body data
|
|
||||||
files: Files to upload
|
|
||||||
headers: Additional headers
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed JSON response
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
requests.RequestException: If the request fails
|
|
||||||
"""
|
|
||||||
url = urljoin(self.base_url, path)
|
|
||||||
self.check_auth_token(self.api_key)
|
|
||||||
# Combine default headers with any provided headers
|
|
||||||
request_headers = self.get_headers()
|
|
||||||
if headers:
|
|
||||||
request_headers.update(headers)
|
|
||||||
|
|
||||||
# Let requests handle the content type when files are present.
|
|
||||||
if files:
|
|
||||||
del request_headers["Content-Type"]
|
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
|
||||||
logging.debug(f"[DEBUG] Files: {files}")
|
|
||||||
logging.debug(f"[DEBUG] Params: {params}")
|
|
||||||
logging.debug(f"[DEBUG] Json: {json}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# If files are present, use data parameter instead of json
|
|
||||||
if files:
|
|
||||||
form_data = {}
|
|
||||||
if json:
|
|
||||||
form_data.update(json)
|
|
||||||
response = requests.request(
|
|
||||||
method=method,
|
|
||||||
url=url,
|
|
||||||
params=params,
|
|
||||||
data=form_data, # Use data instead of json
|
|
||||||
files=files,
|
|
||||||
headers=request_headers,
|
|
||||||
timeout=self.timeout,
|
|
||||||
verify=self.verify_ssl,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = requests.request(
|
|
||||||
method=method,
|
|
||||||
url=url,
|
|
||||||
params=params,
|
|
||||||
json=json,
|
|
||||||
headers=request_headers,
|
|
||||||
timeout=self.timeout,
|
|
||||||
verify=self.verify_ssl,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Raise exception for error status codes
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.ConnectionError:
|
|
||||||
raise Exception(
|
|
||||||
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
|
|
||||||
)
|
|
||||||
|
|
||||||
except requests.Timeout:
|
|
||||||
raise Exception(
|
|
||||||
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
|
|
||||||
)
|
|
||||||
|
|
||||||
except requests.HTTPError as e:
|
|
||||||
status_code = e.response.status_code if hasattr(e, "response") else None
|
|
||||||
error_message = f"HTTP Error: {str(e)}"
|
|
||||||
|
|
||||||
# Try to extract detailed error message from JSON response
|
|
||||||
try:
|
|
||||||
if hasattr(e, "response") and e.response.content:
|
|
||||||
error_json = e.response.json()
|
|
||||||
if "error" in error_json and "message" in error_json["error"]:
|
|
||||||
error_message = f"API Error: {error_json['error']['message']}"
|
|
||||||
if "type" in error_json["error"]:
|
|
||||||
error_message += f" (Type: {error_json['error']['type']})"
|
|
||||||
else:
|
|
||||||
error_message = f"API Error: {error_json}"
|
|
||||||
except Exception as json_error:
|
|
||||||
# If we can't parse the JSON, fall back to the original error message
|
|
||||||
logging.debug(f"[DEBUG] Failed to parse error response: {str(json_error)}")
|
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
|
|
||||||
if hasattr(e, "response") and e.response.content:
|
|
||||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
|
||||||
if status_code == 401:
|
|
||||||
error_message = "Unauthorized: Please login first to use this node."
|
|
||||||
if status_code == 402:
|
|
||||||
error_message = "Payment Required: Please add credits to your account to use this node."
|
|
||||||
if status_code == 409:
|
|
||||||
error_message = "There is a problem with your account. Please contact support@comfy.org. "
|
|
||||||
if status_code == 429:
|
|
||||||
error_message = "Rate Limit Exceeded: Please try again later."
|
|
||||||
raise Exception(error_message)
|
|
||||||
|
|
||||||
# Parse and return JSON response
|
|
||||||
if response.content:
|
|
||||||
return response.json()
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def check_auth_token(self, auth_token):
|
|
||||||
"""Verify that an auth token is present."""
|
|
||||||
if auth_token is None:
|
|
||||||
raise Exception("Unauthorized: Please login first to use this node.")
|
|
||||||
return auth_token
|
|
||||||
|
|
||||||
|
|
||||||
class ApiEndpoint(Generic[T, R]):
|
|
||||||
"""Defines an API endpoint with its request and response types"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
method: HttpMethod,
|
|
||||||
request_model: Type[T],
|
|
||||||
response_model: Type[R],
|
|
||||||
query_params: Optional[Dict[str, Any]] = None,
|
|
||||||
):
|
|
||||||
"""Initialize an API endpoint definition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: The URL path for this endpoint, can include placeholders like {id}
|
|
||||||
method: The HTTP method to use (GET, POST, etc.)
|
|
||||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
|
||||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
|
||||||
query_params: Optional dictionary of query parameters to include in the request
|
|
||||||
"""
|
|
||||||
self.path = path
|
|
||||||
self.method = method
|
|
||||||
self.request_model = request_model
|
|
||||||
self.response_model = response_model
|
|
||||||
self.query_params = query_params or {}
|
|
||||||
|
|
||||||
|
|
||||||
class SynchronousOperation(Generic[T, R]):
|
|
||||||
"""
|
|
||||||
Represents a single synchronous API operation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
endpoint: ApiEndpoint[T, R],
|
|
||||||
request: T,
|
|
||||||
files: Optional[Dict[str, Any]] = None,
|
|
||||||
api_base: str = "https://api.comfy.org",
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
timeout: float = 604800.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
):
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.request = request
|
|
||||||
self.response = None
|
|
||||||
self.error = None
|
|
||||||
self.api_base = api_base
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
self.files = files
|
|
||||||
def execute(self, client: Optional[ApiClient] = None) -> R:
|
|
||||||
"""Execute the API operation using the provided client or create one"""
|
|
||||||
try:
|
|
||||||
# Create client if not provided
|
|
||||||
if client is None:
|
|
||||||
if self.api_base is None:
|
|
||||||
raise ValueError("Either client or api_base must be provided")
|
|
||||||
client = ApiClient(
|
|
||||||
base_url=self.api_base,
|
|
||||||
api_key=self.auth_token,
|
|
||||||
timeout=self.timeout,
|
|
||||||
verify_ssl=self.verify_ssl,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert request model to dict, but use None for EmptyRequest
|
|
||||||
request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True)
|
|
||||||
|
|
||||||
# Debug log for request
|
|
||||||
logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}")
|
|
||||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
|
||||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
|
||||||
|
|
||||||
# Make the request
|
|
||||||
resp = client.request(
|
|
||||||
method=self.endpoint.method.value,
|
|
||||||
path=self.endpoint.path,
|
|
||||||
json=request_dict,
|
|
||||||
params=self.endpoint.query_params,
|
|
||||||
files=self.files,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Debug log for response
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
|
||||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
|
||||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
|
|
||||||
# Parse and return the response
|
|
||||||
return self._parse_response(resp)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f"[DEBUG] API Exception: {str(e)}")
|
|
||||||
raise Exception(str(e))
|
|
||||||
|
|
||||||
def _parse_response(self, resp):
|
|
||||||
"""Parse response data - can be overridden by subclasses"""
|
|
||||||
# The response is already the complete object, don't extract just the "data" field
|
|
||||||
# as that would lose the outer structure (created timestamp, etc.)
|
|
||||||
|
|
||||||
# Parse response using the provided model
|
|
||||||
self.response = self.endpoint.response_model.model_validate(resp)
|
|
||||||
logging.debug(f"[DEBUG] Parsed Response: {self.response}")
|
|
||||||
return self.response
|
|
||||||
@@ -1,442 +0,0 @@
|
|||||||
import io
|
|
||||||
from inspect import cleandoc
|
|
||||||
|
|
||||||
from comfy.utils import common_upscale
|
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
|
||||||
from comfy_api_nodes.apis import (
|
|
||||||
OpenAIImageGenerationRequest,
|
|
||||||
OpenAIImageEditRequest,
|
|
||||||
OpenAIImageGenerationResponse
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import base64
|
|
||||||
|
|
||||||
def downscale_input(image):
|
|
||||||
samples = image.movedim(-1,1)
|
|
||||||
#downscaling input images to roughly the same size as the outputs
|
|
||||||
total = int(1536 * 1024)
|
|
||||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
|
||||||
if scale_by >= 1:
|
|
||||||
return image
|
|
||||||
width = round(samples.shape[3] * scale_by)
|
|
||||||
height = round(samples.shape[2] * scale_by)
|
|
||||||
|
|
||||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
|
||||||
s = s.movedim(1,-1)
|
|
||||||
return s
|
|
||||||
|
|
||||||
def validate_and_cast_response(response):
|
|
||||||
# validate raw JSON response
|
|
||||||
data = response.data
|
|
||||||
if not data or len(data) == 0:
|
|
||||||
raise Exception("No images returned from API endpoint")
|
|
||||||
|
|
||||||
# Initialize list to store image tensors
|
|
||||||
image_tensors = []
|
|
||||||
|
|
||||||
# Process each image in the data array
|
|
||||||
for image_data in data:
|
|
||||||
image_url = image_data.url
|
|
||||||
b64_data = image_data.b64_json
|
|
||||||
|
|
||||||
if not image_url and not b64_data:
|
|
||||||
raise Exception("No image was generated in the response")
|
|
||||||
|
|
||||||
if b64_data:
|
|
||||||
img_data = base64.b64decode(b64_data)
|
|
||||||
img = Image.open(io.BytesIO(img_data))
|
|
||||||
|
|
||||||
elif image_url:
|
|
||||||
img_response = requests.get(image_url)
|
|
||||||
if img_response.status_code != 200:
|
|
||||||
raise Exception("Failed to download the image")
|
|
||||||
img = Image.open(io.BytesIO(img_response.content))
|
|
||||||
|
|
||||||
img = img.convert("RGBA")
|
|
||||||
|
|
||||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
|
||||||
img_array = np.array(img).astype(np.float32) / 255.0
|
|
||||||
img_tensor = torch.from_numpy(img_array)
|
|
||||||
|
|
||||||
# Add to list of tensors
|
|
||||||
image_tensors.append(img_tensor)
|
|
||||||
|
|
||||||
return torch.stack(image_tensors, dim=0)
|
|
||||||
|
|
||||||
class OpenAIDalle2(ComfyNodeABC):
|
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
|
||||||
|
|
||||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
|
||||||
so download or cache results if you need to keep them.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"prompt": (IO.STRING, {
|
|
||||||
"multiline": True,
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Text prompt for DALL·E",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (IO.INT, {
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 2**31-1,
|
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "not implemented yet in backend",
|
|
||||||
}),
|
|
||||||
"size": (IO.COMBO, {
|
|
||||||
"options": ["256x256", "512x512", "1024x1024"],
|
|
||||||
"default": "1024x1024",
|
|
||||||
"tooltip": "Image size",
|
|
||||||
}),
|
|
||||||
"n": (IO.INT, {
|
|
||||||
"default": 1,
|
|
||||||
"min": 1,
|
|
||||||
"max": 8,
|
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "How many images to generate",
|
|
||||||
}),
|
|
||||||
"image": (IO.IMAGE, {
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional reference image for image editing.",
|
|
||||||
}),
|
|
||||||
"mask": (IO.MASK, {
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
|
||||||
FUNCTION = "api_call"
|
|
||||||
CATEGORY = "api node"
|
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None):
|
|
||||||
model = "dall-e-2"
|
|
||||||
path = "/proxy/openai/images/generations"
|
|
||||||
request_class = OpenAIImageGenerationRequest
|
|
||||||
img_binary = None
|
|
||||||
|
|
||||||
if image is not None and mask is not None:
|
|
||||||
path = "/proxy/openai/images/edits"
|
|
||||||
request_class = OpenAIImageEditRequest
|
|
||||||
|
|
||||||
input_tensor = image.squeeze().cpu()
|
|
||||||
height, width, channels = input_tensor.shape
|
|
||||||
rgba_tensor = torch.ones(height, width, 4, device="cpu")
|
|
||||||
rgba_tensor[:, :, :channels] = input_tensor
|
|
||||||
|
|
||||||
if mask.shape[1:] != image.shape[1:-1]:
|
|
||||||
raise Exception("Mask and Image must be the same size")
|
|
||||||
rgba_tensor[:,:,3] = (1-mask.squeeze().cpu())
|
|
||||||
|
|
||||||
rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze()
|
|
||||||
|
|
||||||
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
|
|
||||||
img = Image.fromarray(image_np)
|
|
||||||
img_byte_arr = io.BytesIO()
|
|
||||||
img.save(img_byte_arr, format='PNG')
|
|
||||||
img_byte_arr.seek(0)
|
|
||||||
img_binary = img_byte_arr#.getvalue()
|
|
||||||
img_binary.name = "image.png"
|
|
||||||
elif image is not None or mask is not None:
|
|
||||||
raise Exception("Dall-E 2 image editing requires an image AND a mask")
|
|
||||||
|
|
||||||
# Build the operation
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=path,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=request_class,
|
|
||||||
response_model=OpenAIImageGenerationResponse
|
|
||||||
),
|
|
||||||
request=request_class(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
n=n,
|
|
||||||
size=size,
|
|
||||||
seed=seed,
|
|
||||||
),
|
|
||||||
files={
|
|
||||||
"image": img_binary,
|
|
||||||
} if img_binary else None,
|
|
||||||
auth_token=auth_token
|
|
||||||
)
|
|
||||||
|
|
||||||
response = operation.execute()
|
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response)
|
|
||||||
return (img_tensor,)
|
|
||||||
|
|
||||||
class OpenAIDalle3(ComfyNodeABC):
|
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
|
||||||
|
|
||||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
|
||||||
so download or cache results if you need to keep them.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"prompt": (IO.STRING, {
|
|
||||||
"multiline": True,
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Text prompt for DALL·E",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (IO.INT, {
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 2**31-1,
|
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "not implemented yet in backend",
|
|
||||||
}),
|
|
||||||
"quality" : (IO.COMBO, {
|
|
||||||
"options": ["standard","hd"],
|
|
||||||
"default": "standard",
|
|
||||||
"tooltip": "Image quality",
|
|
||||||
}),
|
|
||||||
"style": (IO.COMBO, {
|
|
||||||
"options": ["natural","vivid"],
|
|
||||||
"default": "natural",
|
|
||||||
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
|
|
||||||
}),
|
|
||||||
"size": (IO.COMBO, {
|
|
||||||
"options": ["1024x1024", "1024x1792", "1792x1024"],
|
|
||||||
"default": "1024x1024",
|
|
||||||
"tooltip": "Image size",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
|
||||||
FUNCTION = "api_call"
|
|
||||||
CATEGORY = "api node"
|
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None):
|
|
||||||
model = "dall-e-3"
|
|
||||||
|
|
||||||
# build the operation
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/openai/images/generations",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=OpenAIImageGenerationRequest,
|
|
||||||
response_model=OpenAIImageGenerationResponse
|
|
||||||
),
|
|
||||||
request=OpenAIImageGenerationRequest(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
quality=quality,
|
|
||||||
size=size,
|
|
||||||
style=style,
|
|
||||||
seed=seed,
|
|
||||||
),
|
|
||||||
auth_token=auth_token
|
|
||||||
)
|
|
||||||
|
|
||||||
response = operation.execute()
|
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response)
|
|
||||||
return (img_tensor,)
|
|
||||||
|
|
||||||
class OpenAIGPTImage1(ComfyNodeABC):
|
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
|
||||||
|
|
||||||
Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived,
|
|
||||||
so download or cache results if you need to keep them.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"prompt": (IO.STRING, {
|
|
||||||
"multiline": True,
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Text prompt for GPT Image 1",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"seed": (IO.INT, {
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 2**31-1,
|
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "not implemented yet in backend",
|
|
||||||
}),
|
|
||||||
"quality": (IO.COMBO, {
|
|
||||||
"options": ["low","medium","high"],
|
|
||||||
"default": "low",
|
|
||||||
"tooltip": "Image quality, affects cost and generation time.",
|
|
||||||
}),
|
|
||||||
"background": (IO.COMBO, {
|
|
||||||
"options": ["opaque","transparent"],
|
|
||||||
"default": "opaque",
|
|
||||||
"tooltip": "Return image with or without background",
|
|
||||||
}),
|
|
||||||
"size": (IO.COMBO, {
|
|
||||||
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
|
|
||||||
"default": "auto",
|
|
||||||
"tooltip": "Image size",
|
|
||||||
}),
|
|
||||||
"n": (IO.INT, {
|
|
||||||
"default": 1,
|
|
||||||
"min": 1,
|
|
||||||
"max": 8,
|
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "How many images to generate",
|
|
||||||
}),
|
|
||||||
"image": (IO.IMAGE, {
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional reference image for image editing.",
|
|
||||||
}),
|
|
||||||
"mask": (IO.MASK, {
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
|
||||||
FUNCTION = "api_call"
|
|
||||||
CATEGORY = "api node"
|
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
|
|
||||||
model = "gpt-image-1"
|
|
||||||
path = "/proxy/openai/images/generations"
|
|
||||||
request_class = OpenAIImageGenerationRequest
|
|
||||||
img_binaries = []
|
|
||||||
mask_binary = None
|
|
||||||
files = []
|
|
||||||
|
|
||||||
if image is not None:
|
|
||||||
path = "/proxy/openai/images/edits"
|
|
||||||
request_class = OpenAIImageEditRequest
|
|
||||||
|
|
||||||
batch_size = image.shape[0]
|
|
||||||
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
|
||||||
single_image = image[i:i+1]
|
|
||||||
scaled_image = downscale_input(single_image).squeeze()
|
|
||||||
|
|
||||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
|
||||||
img = Image.fromarray(image_np)
|
|
||||||
img_byte_arr = io.BytesIO()
|
|
||||||
img.save(img_byte_arr, format='PNG')
|
|
||||||
img_byte_arr.seek(0)
|
|
||||||
img_binary = img_byte_arr
|
|
||||||
img_binary.name = f"image_{i}.png"
|
|
||||||
|
|
||||||
img_binaries.append(img_binary)
|
|
||||||
if batch_size == 1:
|
|
||||||
files.append(("image", img_binary))
|
|
||||||
else:
|
|
||||||
files.append(("image[]", img_binary))
|
|
||||||
|
|
||||||
if mask is not None:
|
|
||||||
if image.shape[0] != 1:
|
|
||||||
raise Exception("Cannot use a mask with multiple image")
|
|
||||||
if image is None:
|
|
||||||
raise Exception("Cannot use a mask without an input image")
|
|
||||||
if mask.shape[1:] != image.shape[1:-1]:
|
|
||||||
raise Exception("Mask and Image must be the same size")
|
|
||||||
batch, height, width = mask.shape
|
|
||||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
|
||||||
rgba_mask[:,:,3] = (1-mask.squeeze().cpu())
|
|
||||||
|
|
||||||
scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze()
|
|
||||||
|
|
||||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
|
||||||
mask_img = Image.fromarray(mask_np)
|
|
||||||
mask_img_byte_arr = io.BytesIO()
|
|
||||||
mask_img.save(mask_img_byte_arr, format='PNG')
|
|
||||||
mask_img_byte_arr.seek(0)
|
|
||||||
mask_binary = mask_img_byte_arr
|
|
||||||
mask_binary.name = "mask.png"
|
|
||||||
files.append(("mask", mask_binary))
|
|
||||||
|
|
||||||
|
|
||||||
# Build the operation
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=path,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=request_class,
|
|
||||||
response_model=OpenAIImageGenerationResponse
|
|
||||||
),
|
|
||||||
request=request_class(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
quality=quality,
|
|
||||||
background=background,
|
|
||||||
n=n,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
),
|
|
||||||
files=files if files else None,
|
|
||||||
auth_token=auth_token
|
|
||||||
)
|
|
||||||
|
|
||||||
response = operation.execute()
|
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response)
|
|
||||||
return (img_tensor,)
|
|
||||||
|
|
||||||
|
|
||||||
# A dictionary that contains all nodes you want to export with their names
|
|
||||||
# NOTE: names should be globally unique
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"OpenAIDalle2": OpenAIDalle2,
|
|
||||||
"OpenAIDalle3": OpenAIDalle3,
|
|
||||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
|
||||||
}
|
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
|
||||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
|
||||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import Type, Literal
|
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
|
||||||
|
|
||||||
class DependencyCycleError(Exception):
|
class DependencyCycleError(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -57,22 +54,7 @@ class DynamicPrompt:
|
|||||||
def get_original_prompt(self):
|
def get_original_prompt(self):
|
||||||
return self.original_prompt
|
return self.original_prompt
|
||||||
|
|
||||||
def get_input_info(
|
def get_input_info(class_def, input_name, valid_inputs=None):
|
||||||
class_def: Type[ComfyNodeABC],
|
|
||||||
input_name: str,
|
|
||||||
valid_inputs: InputTypeDict | None = None
|
|
||||||
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
|
|
||||||
"""Get the input type, category, and extra info for a given input name.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
class_def: The class definition of the node.
|
|
||||||
input_name: The name of the input to get info for.
|
|
||||||
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
||||||
input_info = None
|
input_info = None
|
||||||
input_category = None
|
input_category = None
|
||||||
@@ -144,7 +126,7 @@ class TopologicalSort:
|
|||||||
from_node_id, from_socket = value
|
from_node_id, from_socket = value
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
continue
|
continue
|
||||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
node_ids.append(from_node_id)
|
node_ids.append(from_node_id)
|
||||||
|
|||||||
@@ -1,100 +0,0 @@
|
|||||||
# Code based on https://github.com/WikiChao/FreSca (MIT License)
|
|
||||||
import torch
|
|
||||||
import torch.fft as fft
|
|
||||||
|
|
||||||
|
|
||||||
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
|
||||||
"""
|
|
||||||
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
x: Input tensor of shape (B, C, H, W)
|
|
||||||
scale_low: Scaling factor for low-frequency components (default: 1.0)
|
|
||||||
scale_high: Scaling factor for high-frequency components (default: 1.5)
|
|
||||||
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
|
|
||||||
"""
|
|
||||||
# Preserve input dtype and device
|
|
||||||
dtype, device = x.dtype, x.device
|
|
||||||
|
|
||||||
# Convert to float32 for FFT computations
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
|
|
||||||
# 1) Apply FFT and shift low frequencies to center
|
|
||||||
x_freq = fft.fftn(x, dim=(-2, -1))
|
|
||||||
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
|
||||||
|
|
||||||
# Initialize mask with high-frequency scaling factor
|
|
||||||
mask = torch.ones(x_freq.shape, device=device) * scale_high
|
|
||||||
m = mask
|
|
||||||
for d in range(len(x_freq.shape) - 2):
|
|
||||||
dim = d + 2
|
|
||||||
cc = x_freq.shape[dim] // 2
|
|
||||||
f_c = min(freq_cutoff, cc)
|
|
||||||
m = m.narrow(dim, cc - f_c, f_c * 2)
|
|
||||||
|
|
||||||
# Apply low-frequency scaling factor to center region
|
|
||||||
m[:] = scale_low
|
|
||||||
|
|
||||||
# 3) Apply frequency-specific scaling
|
|
||||||
x_freq = x_freq * mask
|
|
||||||
|
|
||||||
# 4) Convert back to spatial domain
|
|
||||||
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
|
||||||
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
|
||||||
|
|
||||||
# 5) Restore original dtype
|
|
||||||
x_filtered = x_filtered.to(dtype)
|
|
||||||
|
|
||||||
return x_filtered
|
|
||||||
|
|
||||||
|
|
||||||
class FreSca:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("MODEL",),
|
|
||||||
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
|
|
||||||
"tooltip": "Scaling factor for low-frequency components"}),
|
|
||||||
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
|
|
||||||
"tooltip": "Scaling factor for high-frequency components"}),
|
|
||||||
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
|
|
||||||
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ("MODEL",)
|
|
||||||
FUNCTION = "patch"
|
|
||||||
CATEGORY = "_for_testing"
|
|
||||||
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
|
|
||||||
def patch(self, model, scale_low, scale_high, freq_cutoff):
|
|
||||||
def custom_cfg_function(args):
|
|
||||||
cond = args["conds_out"][0]
|
|
||||||
uncond = args["conds_out"][1]
|
|
||||||
|
|
||||||
guidance = cond - uncond
|
|
||||||
filtered_guidance = Fourier_filter(
|
|
||||||
guidance,
|
|
||||||
scale_low=scale_low,
|
|
||||||
scale_high=scale_high,
|
|
||||||
freq_cutoff=freq_cutoff,
|
|
||||||
)
|
|
||||||
filtered_cond = filtered_guidance + uncond
|
|
||||||
|
|
||||||
return [filtered_cond, uncond]
|
|
||||||
|
|
||||||
m = model.clone()
|
|
||||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
|
||||||
|
|
||||||
return (m,)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"FreSca": FreSca,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"FreSca": "FreSca",
|
|
||||||
}
|
|
||||||
@@ -26,30 +26,7 @@ class QuadrupleCLIPLoader:
|
|||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
class CLIPTextEncodeHiDream:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {
|
|
||||||
"clip": ("CLIP", ),
|
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
|
||||||
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
|
||||||
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
|
||||||
|
|
||||||
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
|
|
||||||
|
|
||||||
tokens = clip.tokenize(clip_g)
|
|
||||||
tokens["l"] = clip.tokenize(clip_l)["l"]
|
|
||||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
|
||||||
tokens["llama"] = clip.tokenize(llama)["llama"]
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
|
||||||
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ class Load3D():
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -41,7 +41,7 @@ class Load3D():
|
|||||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
return output_image, output_mask, model_file, normal_image, lineart_image
|
||||||
|
|
||||||
class Load3DAnimation():
|
class Load3DAnimation():
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -59,8 +59,8 @@ class Load3DAnimation():
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -77,16 +77,13 @@ class Load3DAnimation():
|
|||||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
return output_image, output_mask, model_file, normal_image
|
||||||
|
|
||||||
class Preview3D():
|
class Preview3D():
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"camera_info": ("LOAD3D_CAMERA", {})
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@@ -98,22 +95,13 @@ class Preview3D():
|
|||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def process(self, model_file, **kwargs):
|
def process(self, model_file, **kwargs):
|
||||||
camera_info = kwargs.get("camera_info", None)
|
return {"ui": {"model_file": [model_file]}, "result": ()}
|
||||||
|
|
||||||
return {
|
|
||||||
"ui": {
|
|
||||||
"result": [model_file, camera_info]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class Preview3DAnimation():
|
class Preview3DAnimation():
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
"model_file": ("STRING", {"default": "", "multiline": False}),
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"camera_info": ("LOAD3D_CAMERA", {})
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@@ -125,13 +113,7 @@ class Preview3DAnimation():
|
|||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def process(self, model_file, **kwargs):
|
def process(self, model_file, **kwargs):
|
||||||
camera_info = kwargs.get("camera_info", None)
|
return {"ui": {"model_file": [model_file]}, "result": ()}
|
||||||
|
|
||||||
return {
|
|
||||||
"ui": {
|
|
||||||
"result": [model_file, camera_info]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Load3D": Load3D,
|
"Load3D": Load3D,
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ def encode_single_frame(output_file, image_array: np.ndarray, crf):
|
|||||||
container = av.open(output_file, "w", format="mp4")
|
container = av.open(output_file, "w", format="mp4")
|
||||||
try:
|
try:
|
||||||
stream = container.add_stream(
|
stream = container.add_stream(
|
||||||
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
||||||
)
|
)
|
||||||
stream.height = image_array.shape[0]
|
stream.height = image_array.shape[0]
|
||||||
stream.width = image_array.shape[1]
|
stream.width = image_array.shape[1]
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ import scipy.ndimage
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import folder_paths
|
|
||||||
import random
|
|
||||||
|
|
||||||
import nodes
|
|
||||||
from nodes import MAX_RESOLUTION
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||||
@@ -365,30 +362,6 @@ class ThresholdMask:
|
|||||||
mask = (mask > value).float()
|
mask = (mask > value).float()
|
||||||
return (mask,)
|
return (mask,)
|
||||||
|
|
||||||
# Mask Preview - original implement from
|
|
||||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
|
||||||
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
|
||||||
class MaskPreview(nodes.SaveImage):
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_temp_directory()
|
|
||||||
self.type = "temp"
|
|
||||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
|
||||||
self.compress_level = 4
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {"mask": ("MASK",), },
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
||||||
}
|
|
||||||
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "mask"
|
|
||||||
|
|
||||||
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
|
||||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
|
||||||
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentCompositeMasked": LatentCompositeMasked,
|
"LatentCompositeMasked": LatentCompositeMasked,
|
||||||
@@ -403,7 +376,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"FeatherMask": FeatherMask,
|
"FeatherMask": FeatherMask,
|
||||||
"GrowMask": GrowMask,
|
"GrowMask": GrowMask,
|
||||||
"ThresholdMask": ThresholdMask,
|
"ThresholdMask": ThresholdMask,
|
||||||
"MaskPreview": MaskPreview
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# Primitive nodes that are evaluated at backend.
|
# Primitive nodes that are evaluated at backend.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +23,7 @@ class Int(ComfyNodeABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
return {
|
return {
|
||||||
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
|
"required": {"value": (IO.INT, {"control_after_generate": True})},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT,)
|
RETURN_TYPES = (IO.INT,)
|
||||||
@@ -40,7 +38,7 @@ class Float(ComfyNodeABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
return {
|
return {
|
||||||
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
|
"required": {"value": (IO.FLOAT, {})},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = (IO.FLOAT,)
|
RETURN_TYPES = (IO.FLOAT,)
|
||||||
|
|||||||
@@ -50,15 +50,13 @@ class SaveWEBM:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])

codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
stream.width = images.shape[-2]
stream.height = images.shape[-3]
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
stream.pix_fmt = "yuv420p"
stream.bit_rate = 0
stream.options = {'crf': str(crf)}
if codec == "av1":
stream.options["preset"] = "6"

for frame in images:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")

@@ -4,7 +4,6 @@ import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision


class WanImageToVideo:
@@ -100,72 +99,6 @@ class WanFunControlToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)

class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))

if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0

if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0

concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image

if clip_vision_end_image is not None:
if clip_vision_output is not None:
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
clip_vision_output = comfy.clip_vision.Output()
clip_vision_output.penultimate_hidden_states = states
else:
clip_vision_output = clip_vision_end_image

if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)


class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
@@ -189,120 +122,38 @@ class WanFunInpaintToVideo:
CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))

class WanVaceToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
},
"optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ),
"reference_image": ("IMAGE", ),
}}
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
FUNCTION = "encode"
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0

CATEGORY = "conditioning/video_models"
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

EXPERIMENTAL = True
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
latent_length = ((length - 1) // 4) + 1
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if control_video.shape[0] < length:
control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
else:
control_video = torch.ones((length, height, width, 3)) * 0.5

if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
reference_image = vae.encode(reference_image[:, :, :, :3])
reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)

if control_masks is None:
mask = torch.ones((length, height, width, 1))
else:
mask = control_masks
if mask.ndim == 3:
mask = mask.unsqueeze(1)
mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
if mask.shape[0] < length:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)

control_video = control_video - 0.5
inactive = (control_video * (1 - mask)) + 0.5
reactive = (control_video * mask) + 0.5

inactive = vae.encode(inactive[:, :, :, :3])
reactive = vae.encode(reactive[:, :, :, :3])
control_video_latent = torch.cat((inactive, reactive), dim=1)
if reference_image is not None:
control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)

vae_stride = 8
height_mask = height // vae_stride
width_mask = width // vae_stride
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
mask = mask.permute(2, 4, 0, 1, 3)
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)

trim_latent = 0
if reference_image is not None:
mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
mask = torch.cat((mask_pad, mask), dim=1)
latent_length += reference_image.shape[2]
trim_latent = reference_image.shape[2]

mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})

latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent)
return (positive, negative, out_latent)

class TrimVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "op"

CATEGORY = "latent/video"

EXPERIMENTAL = True

def op(self, samples, trim_amount):
samples_out = samples.copy()

s1 = samples["samples"]
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)


NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
}

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.30"
__version__ = "0.3.28"

33
execution.py
@@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@@ -144,8 +144,6 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
return input_data_all, missing_keys

map_node_over_list = None #Don't hook this please
@@ -576,7 +574,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}

for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
@@ -592,7 +590,7 @@ def validate_inputs(prompt, item, validated):
continue

val = inputs[x]
info = (input_type, extra_info)
info = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@@ -613,8 +611,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
@@ -662,22 +660,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"]
inputs[x] = val

if input_type == "INT":
if type_input == "INT":
val = int(val)
inputs[x] = val
if input_type == "FLOAT":
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if input_type == "STRING":
if type_input == "STRING":
val = str(val)
inputs[x] = val
if input_type == "BOOLEAN":
if type_input == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {input_type} value",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
@@ -717,19 +715,18 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue

if isinstance(input_type, list):
if isinstance(type_input, list):
combo_options = input_type
if val not in type_input:
if val not in combo_options:
input_config = info
list_info = ""

# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(combo_options) > 20:
if len(type_input) > 20:
list_info = f"(list of length {len(combo_options)})"
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(combo_options)
list_info = str(type_input)

error = {
"type": "value_not_in_list",

51
nodes.py
@@ -917,7 +917,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -927,10 +927,29 @@ class CLIPLoader:

CATEGORY = "advanced/loaders"

DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"

def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
elif type == "cosmos":
clip_type = comfy.sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = comfy.sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

model_options = {}
if device == "cpu":
@@ -945,7 +964,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -955,13 +974,19 @@ class DualCLIPLoader:

CATEGORY = "advanced/loaders"

DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"

def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO

model_options = {}
if device == "cpu":
@@ -2256,13 +2281,7 @@ def init_builtin_extra_nodes():
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_hidream.py"
"nodes_fresca.py",
]

api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
api_nodes_files = [
"nodes_api.py",
]

import_failed = []
@@ -2270,10 +2289,6 @@ def init_builtin_extra_nodes():
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
import_failed.append(node_file)

for node_file in api_nodes_files:
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
import_failed.append(node_file)

return import_failed

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.30"
version = "0.3.28"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

@@ -1,5 +1,4 @@
comfyui-frontend-package==1.17.11
comfyui-frontend-package==1.15.13
comfyui-workflow-templates==0.1.3
torch
torchsde
torchvision
@@ -22,5 +21,4 @@ psutil
kornia>=0.7.1
spandrel
soundfile
av>=14.1.0
av
pydantic~=2.0

@@ -580,9 +580,6 @@ class PromptServer():
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True

if hasattr(obj_class, 'API_NODE'):
info['api_node'] = obj_class.API_NODE
return info

@routes.get("/object_info")
@@ -739,12 +736,6 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])

workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])

self.app.add_routes([
web.static('/', self.web_root),
])