Compare commits

..

3 Commits

Author SHA1 Message Date
Chenlei Hu
522d923948 nit 2025-03-25 16:47:52 -04:00
Chenlei Hu
c05c9b552b nit 2025-03-25 16:47:42 -04:00
Chenlei Hu
27598702e9 [Type] Annotate graph.get_input_info 2025-03-25 16:44:55 -04:00
88 changed files with 565 additions and 4995 deletions

View File

@@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.git_tag }}
fetch-depth: 150
fetch-depth: 0
persist-credentials: false
- uses: actions/cache/restore@v4
id: cache
@@ -70,7 +70,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@@ -85,7 +85,7 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
cd ComfyUI_windows_portable

View File

@@ -1,47 +0,0 @@
name: Generate Pydantic Stubs from api.comfy.org
on:
schedule:
- cron: '0 0 * * 1'
workflow_dispatch:
jobs:
generate-models:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install 'datamodel-code-generator[http]'
- name: Generate API models
run: |
datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
- name: Check for changes
id: git-check
run: |
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
- name: Create Pull Request
if: steps.git-check.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v5
with:
commit-message: 'chore: update API models from OpenAPI spec'
title: 'Update API models from api.comfy.org'
body: |
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
Generated automatically by the a Github workflow.
branch: update-api-stubs
delete-branch: true
base: main

View File

@@ -56,7 +56,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable_nightly_pytorch
mv python_embeded ComfyUI_windows_portable_nightly_pytorch

View File

@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 150
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
@@ -67,7 +67,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@@ -82,7 +82,7 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

View File

@@ -62,7 +62,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
@@ -216,9 +215,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
This is the command to install pytorch nightly instead which might have performance improvements.
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```

View File

@@ -9,14 +9,8 @@ class AppSettings():
self.user_manager = user_manager
def get_settings(self, request):
try:
file = self.user_manager.get_request_user_filepath(
request,
"comfy.settings.json"
)
except KeyError as e:
logging.error("User settings not found.")
raise web.HTTPUnauthorized() from e
request, "comfy.settings.json")
if os.path.isfile(file):
try:
with open(file) as f:

View File

@@ -184,27 +184,6 @@ comfyui-frontend-package is not installed.
)
sys.exit(-1)
@classmethod
def templates_path(cls) -> str:
try:
import comfyui_workflow_templates
return str(
importlib.resources.files(comfyui_workflow_templates) / "templates"
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""

View File

@@ -66,7 +66,6 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@@ -80,7 +79,6 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
@@ -102,7 +100,6 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -137,9 +134,8 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

View File

@@ -18,7 +18,6 @@ class Output:
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
@@ -111,13 +110,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:

View File

@@ -1,13 +0,0 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 512,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@@ -1,7 +1,7 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict, Optional
from typing import Literal, TypedDict
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
@@ -99,64 +99,55 @@ class InputTypeOptions(TypedDict):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
default: NotRequired[bool | str | float | int | list | tuple]
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: NotRequired[bool]
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: NotRequired[bool]
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: NotRequired[bool]
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: NotRequired[bool]
rawLink: bool
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: NotRequired[str]
tooltip: str
"""Tooltip for the input (or widget), shown on pointer hover"""
socketless: NotRequired[bool]
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
Available from frontend v1.17.5
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: NotRequired[float]
min: float
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: NotRequired[float]
max: float
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: NotRequired[float]
step: float
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: NotRequired[float]
round: float
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: NotRequired[str]
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: NotRequired[str]
label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: NotRequired[bool]
multiline: bool
"""Use a multiline text box (``STRING``)"""
placeholder: NotRequired[str]
placeholder: str
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: NotRequired[bool]
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: NotRequired[bool]
image_upload: bool
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: NotRequired[Literal["input", "output", "temp"]]
image_folder: Literal["input", "output", "temp"]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: NotRequired[RemoteInputOptions]
remote: RemoteInputOptions
"""Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: NotRequired[bool]
control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
@@ -174,15 +165,15 @@ class InputTypeOptions(TypedDict):
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: NotRequired[Literal["UNIQUE_ID"]]
node_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: NotRequired[Literal["UNIQUE_ID"]]
unique_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: NotRequired[Literal["PROMPT"]]
prompt: Literal["PROMPT"]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
extra_pnginfo: Literal["EXTRA_PNGINFO"]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: NotRequired[Literal["DYNPROMPT"]]
dynprompt: Literal["DYNPROMPT"]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
@@ -192,11 +183,11 @@ class InputTypeDict(TypedDict):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
required: dict[str, tuple[IO, InputTypeOptions]]
"""Describes all inputs that must be connected for the node to execute."""
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
optional: dict[str, tuple[IO, InputTypeOptions]]
"""Describes inputs which do not need to be connected."""
hidden: NotRequired[HiddenInputTypeDict]
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
@@ -229,8 +220,6 @@ class ComfyNodeABC(ABC):
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
@classmethod
@abstractmethod

View File

@@ -736,7 +736,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
model_options = model_options.copy()
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling

View File

@@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x

View File

@@ -1422,101 +1422,3 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s = t + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x

View File

@@ -1,6 +1,5 @@
import torch
import comfy.rmsnorm
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@@ -12,5 +11,20 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
return torch.nn.functional.pad(img, pad, mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
rms_norm = comfy.rmsnorm.rms_norm
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@@ -1,799 +0,0 @@
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import einops
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.ldm.common_dit
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
dtype=None, device=None, operations=None
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, latent):
latent = self.proj(latent)
return latent
class PooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, pooled_embed):
return self.pooled_embedder(pooled_embed)
class TimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
def forward(self, timesteps, wdtype):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
class HiDreamAttnProcessor_flashattn:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
dtype = image_tokens.dtype
batch_size = image_tokens.shape[0]
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
value_i = attn.to_v(image_tokens)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if image_tokens_masks is not None:
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
value_t = attn.to_v_t(text_tokens)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == rope.shape[-3] * 2:
query, key = apply_rope(query, key, rope)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, rope)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
class HiDreamAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor = None,
out_dim: int = None,
single: bool = False,
dtype=None, device=None, operations=None
):
# super(Attention, self).__init__()
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
linear_cls = operations.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
if not single:
self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
self.processor = processor
def forward(
self,
norm_image_tokens: torch.FloatTensor,
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.Tensor:
return self.processor(
self,
image_tokens = norm_image_tokens,
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
)
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
dtype=None, device=None, operations=None
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
self.reset_parameters()
def reset_parameters(self) -> None:
pass
# import torch.nn.init as init
# init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
dtype=None, device=None, operations=None
):
super().__init__()
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
self.gate = MoEGate(
embed_dim = dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if True: # self.training: # TODO: check which branch performs faster
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
#y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class BlockType:
TransformerBlock = 1
SingleTransformerBlock = 2
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = True,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_image_tokens,
image_tokens_masks,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
image_tokens = ff_output_i + image_tokens
return image_tokens
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
dtype=None, device=None, operations=None
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
)
# nn.init.zeros_(self.adaLN_modulation[1].weight)
# nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = False,
dtype=dtype, device=device, operations=operations
)
# 3. Feed-forward
self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
dtype=dtype, device=device, operations=operations
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_image_tokens,
image_tokens_masks,
norm_text_tokens,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
text_tokens = gate_msa_t * attn_output_t + text_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
image_tokens = ff_output_i + image_tokens
text_tokens = ff_output_t + text_tokens
return image_tokens, text_tokens
class HiDreamImageBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
block_type: BlockType = BlockType.TransformerBlock,
dtype=None, device=None, operations=None
):
super().__init__()
block_classes = {
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
}
self.block = block_classes[block_type](
dim,
num_attention_heads,
attention_head_dim,
num_routed_experts,
num_activated_experts,
dtype=dtype, device=device, operations=operations
)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
return self.block(
image_tokens,
image_tokens_masks,
text_tokens,
adaln_input,
rope,
)
class HiDreamImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
image_model=None,
dtype=None, device=None, operations=None
):
self.patch_size = patch_size
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.num_layers = num_layers
self.num_single_layers = num_single_layers
self.gradient_checkpointing = False
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.inner_dim = self.num_attention_heads * self.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.x_embedder = PatchEmbed(
patch_size = patch_size,
in_channels = in_channels,
out_channels = self.inner_dim,
dtype=dtype, device=device, operations=operations
)
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.TransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.SingleTransformerBlock,
dtype=dtype, device=device, operations=operations
)
for i in range(self.num_single_layers)
]
)
self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.patch_size, p2=self.patch_size)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.patch_size * self.patch_size
if isinstance(x, torch.Tensor):
B = x.shape[0]
device = x.device
dtype = x.dtype
else:
B = len(x)
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0:img_size[0] * img_size[1]] = 1
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
elif isinstance(x, torch.Tensor):
pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
control = None,
transformer_options = {},
) -> torch.Tensor:
bs, c, h, w = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
timesteps = t
pooled_embeds = y
T5_encoder_hidden_states = context
img_sizes = None
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
adaln_input = timesteps + p_embedder
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if image_tokens_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
hidden_states = self.x_embedder(hidden_states)
# T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
3,
device=img_ids.device, dtype=img_ids.dtype
)
ids = torch.cat((img_ids, txt_ids), dim=1)
rope = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if image_tokens_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
)
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
hidden_states = block(
image_tokens=hidden_states,
image_tokens_masks=image_tokens_masks,
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes)
return -output[:, :, :h, :w]

View File

@@ -847,7 +847,6 @@ class SpatialTransformer(nn.Module):
if not isinstance(context, list):
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
x = self.norm(x)
if not self.use_linear:
@@ -963,7 +962,6 @@ class SpatialVideoTransformer(SpatialTransformer):
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
spatial_context = None
if exists(context):

View File

@@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :context_img_len]
context = context[:, context_img_len:]
context_img = context[:, :257]
context = context[:, 257:]
# compute query, key, value
q = self.norm_q(self.q(x))
@@ -193,7 +193,6 @@ class WanAttentionBlock(nn.Module):
e,
freqs,
context,
context_img_len=257,
):
r"""
Args:
@@ -214,40 +213,12 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0,
operation_settings={}
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
self.block_id = block_id
if block_id == 0:
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c_skip, c
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -279,7 +250,7 @@ class Head(nn.Module):
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
def __init__(self, in_dim, out_dim, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
@@ -287,15 +258,7 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
if flf_pos_embed_token_number is not None:
self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
else:
self.emb_pos = None
def forward(self, image_embeds):
if self.emb_pos is not None:
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@@ -321,7 +284,6 @@ class WanModel(torch.nn.Module):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
device=None,
dtype=None,
@@ -411,7 +373,7 @@ class WanModel(torch.nn.Module):
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
else:
self.img_emb = None
@@ -423,7 +385,6 @@ class WanModel(torch.nn.Module):
clip_fea=None,
freqs=None,
transformer_options={},
**kwargs,
):
r"""
Forward pass through the diffusion model
@@ -459,12 +420,9 @@ class WanModel(torch.nn.Module):
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -472,12 +430,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context)
# head
x = self.head(x, e)
@@ -486,7 +444,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
@@ -500,7 +458,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""
@@ -525,115 +483,3 @@ class WanModel(torch.nn.Module):
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
return u
class VaceWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='vace',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
vace_layers=None,
vace_in_dim=None,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
# Vace
if vace_layers is not None:
self.vace_layers = vace_layers
self.vace_in_dim = vace_in_dim
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
for i in range(self.vace_layers)
])
self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
# vace patch embeddings
self.vace_patch_embedding = operations.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
)
def forward_orig(
self,
x,
t,
context,
vace_context,
vace_strength=1.0,
clip_fea=None,
freqs=None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
# arguments
x_orig = x
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import comfy.weight_adapter as weight_adapter
import logging
import torch
@@ -50,12 +49,139 @@ def load_lora(lora, to_load, log_missing=True):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
for adapter_cls in weight_adapter.adapters:
adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
if adapter is not None:
patch_dict[to_load[x]] = adapter
loaded_keys.update(adapter.loaded_keys)
continue
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@@ -282,6 +408,26 @@ def model_lora_keys_unet(model, key_map={}):
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
@@ -336,16 +482,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
if isinstance(v, weight_adapter.WeightAdapterBase):
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
if output is None:
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
else:
weight = output
if old_weight is not None:
weight = old_weight
continue
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
@@ -372,6 +508,157 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))

View File

@@ -1,5 +1,4 @@
import torch
import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
@@ -12,13 +11,7 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
return sd_out
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
return sd

View File

@@ -37,7 +37,6 @@ import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.model_management
import comfy.patcher_extension
@@ -993,40 +992,30 @@ class WAN21(BaseModel):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels == 0:
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video or extra_channels == image.shape[1]:
if not self.image_to_video:
return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
@@ -1043,37 +1032,6 @@ class WAN21(BaseModel):
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
noise = kwargs.get("noise", None)
noise_shape = list(noise.shape)
vace_frames = kwargs.get("vace_frames", None)
if vace_frames is None:
noise_shape[1] = 32
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
for i in range(0, vace_frames.shape[1], 16):
vace_frames = vace_frames.clone()
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
mask = kwargs.get("vace_mask", None)
if mask is None:
noise_shape[1] = 64
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_strength = kwargs.get("vace_strength", 1.0)
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -1088,20 +1046,3 @@ class Hunyuan3Dv2(BaseModel):
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
if conditioning_llama3 is not None:
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
return out

View File

@@ -317,18 +317,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
@@ -346,25 +338,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
dit_config["attention_head_dim"] = 128
dit_config["axes_dims_rope"] = [64, 32, 32]
dit_config["caption_channels"] = [4096, 4096]
dit_config["max_resolution"] = [128, 128]
dit_config["in_channels"] = 16
dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
dit_config["num_attention_heads"] = 20
dit_config["num_routed_experts"] = 4
dit_config["num_activated_experts"] = 2
dit_config["num_layers"] = 16
dit_config["num_single_layers"] = 32
dit_config["out_channels"] = 16
dit_config["patch_size"] = 2
dit_config["text_emb_dim"] = 2048
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -725,8 +725,6 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if args.fp8_e8m0fnu_unet:
return torch.float8_e8m0fnu
fp8_dtype = None
if weight_dtype in FLOAT8_TYPES:
@@ -825,8 +823,6 @@ def text_encoder_dtype(device=None):
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.bf16_text_enc:
return torch.bfloat16
elif args.fp32_text_enc:
return torch.float32
@@ -1239,8 +1235,6 @@ def soft_empty_cache(force=False):
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -147,25 +146,6 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
@@ -263,9 +243,6 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class RMSNorm(disable_weight_init.RMSNorm):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
@@ -380,25 +357,6 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
@@ -411,15 +369,6 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
):
return fp8_ops
if (
PerformanceFeature.CublasOps in args.fast and
CUBLAS_IS_AVAILABLE and
weight_dtype == torch.float16 and
(compute_dtype == torch.float16 or compute_dtype is None)
):
logging.info("Using cublas ops")
return cublas_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init

View File

@@ -48,7 +48,6 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"

View File

@@ -1,55 +0,0 @@
import torch
import comfy.model_management
import numbers
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=None,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@@ -106,13 +106,6 @@ def cleanup_additional_models(models):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)

View File

@@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
"gradient_estimation", "er_sde"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@@ -41,7 +41,6 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.model_patcher
import comfy.lora
@@ -266,7 +265,6 @@ class VAE:
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -339,7 +337,6 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@@ -518,7 +515,7 @@ class VAE:
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -547,7 +544,7 @@ class VAE:
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
@@ -581,7 +578,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
@@ -615,7 +612,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
args = {}
if tile_x is not None:
@@ -703,7 +700,6 @@ class CLIPType(Enum):
COSMOS = 11
LUMINA2 = 12
WAN = 13
HIDREAM = 14
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -792,9 +788,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -815,10 +808,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -835,18 +824,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
@@ -864,33 +845,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
elif clip_type == CLIPType.HIDREAM:
# Detect
hidream_dualclip_classes = []
for hidream_te in clip_data:
te_model = detect_te_model(hidream_te)
hidream_dualclip_classes.append(te_model)
clip_l = TEModel.CLIP_L in hidream_dualclip_classes
clip_g = TEModel.CLIP_G in hidream_dualclip_classes
t5 = TEModel.T5_XXL in hidream_dualclip_classes
llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
# Initialize t5xxl_detect and llama_detect kwargs if needed
t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
llama_kwargs = llama_detect(clip_data) if llama else {}
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif len(clip_data) == 4:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
parameters = 0
for c in clip_data:

View File

@@ -82,8 +82,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
LAYERS = [
"last",
"pooled",
"hidden",
"all"
"hidden"
]
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
@@ -94,8 +93,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"}
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
@@ -103,10 +100,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
for k, v in te_model_options.items():
config[k] = v
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@@ -154,9 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@@ -253,12 +244,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks:
attention_mask_model = attention_mask
if self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "last":
z = outputs[0].float()
@@ -461,7 +447,7 @@ class SDTokenizer:
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.max_length = max_length
self.min_length = min_length
self.end_token = None
@@ -659,7 +645,6 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
model_options = {**model_options, "model_name": self.clip}
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()

View File

@@ -9,7 +9,6 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
@@ -18,13 +17,14 @@ class SDXLClipG(sd1_clip.SDClipModel):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@@ -41,7 +41,8 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@@ -74,7 +75,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -83,7 +84,6 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)

View File

@@ -969,34 +969,12 @@ class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
class WAN21_FunControl2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 48,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "vace",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1035,36 +1013,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
}
sampling_settings = {
"shift": 3.0,
}
sampling_settings = {
}
# memory_usage_factor = 1.2 # TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HiDream(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]

View File

@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -9,13 +9,14 @@ import os
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@@ -34,7 +35,8 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])

View File

@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -1,155 +0,0 @@
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HiDreamTEModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
if llama:
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
if "vocab_size" not in model_options:
model_options["vocab_size"] = 128256
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
self.dtypes.add(dtype_llama)
else:
self.llama = None
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
def set_clip_options(self, options):
if self.clip_l is not None:
self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options)
if self.llama is not None:
self.llama.set_clip_options(options)
def reset_clip_options(self):
if self.clip_l is not None:
self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options()
if self.llama is not None:
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_llama = token_weight_pairs["llama"]
lg_out = None
pooled = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
else:
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
else:
t5_out = None
if self.llama is not None:
ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
ll_out, ll_pooled = ll_output[:2]
ll_out = ll_out[:, 1:]
else:
ll_out = None
if t5_out is None:
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
if ll_out is None:
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
extra["conditioning_llama3"] = ll_out
return t5_out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
return self.t5xxl.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@@ -21,31 +21,26 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
if vocab_size is not None:
textmodel_json_config["vocab_size"] = vocab_size
model_options = {**model_options, "model_name": "llama"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
@@ -77,7 +72,8 @@ class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])

View File

@@ -9,26 +9,24 @@ import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
model_options = {**model_options, "model_name": "hydit_clip"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
model_options = {**model_options, "model_name": "mt5xl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -37,7 +35,7 @@ class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}

View File

@@ -268,17 +268,11 @@ class Llama2_(nn.Module):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
all_intermediate = None
if intermediate_output is not None:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@@ -289,12 +283,6 @@ class Llama2_(nn.Module):
intermediate = x.clone()
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)

View File

@@ -1,27 +1,30 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
else:
model_name = "clip_g"
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None:
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
model_name = "clip_g"
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
model_name = "clip_l"
else:
model_name = "clip_l"
if w is not None:
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
model_config = model_options.get("model_config", {})
model_config["max_position_embeddings"] = w.shape[0]
model_options["{}_model_config".format(model_name)] = model_config
tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -6,7 +6,7 @@ import comfy.text_encoders.llama
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

View File

@@ -15,7 +15,6 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -32,16 +31,17 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
@@ -61,7 +61,8 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None

View File

@@ -1,5 +1,4 @@
import torch
import os
class SPieceTokenizer:
@staticmethod
@@ -16,8 +15,6 @@ class SPieceTokenizer:
if isinstance(tokenizer_path, bytes):
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
else:
if not os.path.isfile(tokenizer_path):
raise ValueError("invalid tokenizer")
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
def get_vocab(self):

View File

@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@@ -1,17 +0,0 @@
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter
adapters: list[type[WeightAdapterBase]] = [
LoRAAdapter,
LoHaAdapter,
LoKrAdapter,
GLoRAAdapter,
OFTAdapter,
BOFTAdapter,
]

View File

@@ -1,104 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
import comfy.model_management
class WeightAdapterBase:
name: str
loaded_keys: set[str]
weights: list[torch.Tensor]
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
raise NotImplementedError
class WeightAdapterTrainBase(nn.Module):
def __init__(self):
super().__init__()
# [TODO] Collaborate with LoRA training PR #7032
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
if wd_on_output_axis:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
)
else:
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor

View File

@@ -1,115 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class BOFTAdapter(WeightAdapterBase):
name = "boft"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["BOFTAdapter"]:
if loaded_keys is None:
loaded_keys = set()
blocks_name = "{}.boft_blocks".format(x)
rescale_name = "{}.rescale".format(x)
blocks = None
if blocks_name in lora.keys():
blocks = lora[blocks_name]
if blocks.ndim == 4:
loaded_keys.add(blocks_name)
rescale = None
if rescale_name in lora.keys():
rescale = lora[rescale_name]
loaded_keys.add(rescale_name)
if blocks is not None:
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
blocks = v[0]
rescale = v[1]
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
boft_m, block_num, boft_b, *_ = blocks.shape
try:
# Get r
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
normed_q = q
if alpha > 0: # alpha in boft/bboft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(original_weight)
inp = org = original_weight
r_b = boft_b//2
for i in range(boft_m):
bi = r[i]
g = 2
k = 2**i * r_b
if strength != 1:
bi = bi * strength + (1-strength) * I
inp = (
inp.unflatten(-1, (-1, g, k))
.transpose(-2, -1)
.flatten(-3)
.unflatten(-1, (-1, boft_b))
)
inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
inp = (
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
)
if rescale is not None:
inp = inp * rescale
lora_diff = inp - org
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,93 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class GLoRAAdapter(WeightAdapterBase):
name = "glora"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["GLoRAAdapter"]:
if loaded_keys is None:
loaded_keys = set()
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,100 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class LoHaAdapter(WeightAdapterBase):
name = "loha"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["LoHaAdapter"]:
if loaded_keys is None:
loaded_keys = set()
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,133 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class LoKrAdapter(WeightAdapterBase):
name = "lokr"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["LoKrAdapter"]:
if loaded_keys is None:
loaded_keys = set()
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,142 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
class LoRAAdapter(WeightAdapterBase):
name = "lora"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["LoRAAdapter"]:
if loaded_keys is None:
loaded_keys = set()
reshape_name = "{}.reshape_weight".format(x)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
loaded_keys.add(A_name)
loaded_keys.add(B_name)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
mat1 = comfy.model_management.cast_to_device(
v[0], weight.device, intermediate_dtype
)
mat2 = comfy.model_management.cast_to_device(
v[1], weight.device, intermediate_dtype
)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(
v[3], weight.device, intermediate_dtype
)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = (
torch.mm(
mat2.transpose(0, 1).flatten(start_dim=1),
mat3.transpose(0, 1).flatten(start_dim=1),
)
.reshape(final_shape)
.transpose(0, 1)
)
try:
lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,94 +0,0 @@
import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
class OFTAdapter(WeightAdapterBase):
name = "oft"
def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["OFTAdapter"]:
if loaded_keys is None:
loaded_keys = set()
blocks_name = "{}.oft_blocks".format(x)
rescale_name = "{}.rescale".format(x)
blocks = None
if blocks_name in lora.keys():
blocks = lora[blocks_name]
if blocks.ndim == 3:
loaded_keys.add(blocks_name)
rescale = None
if rescale_name in lora.keys():
rescale = lora[rescale_name]
loaded_keys.add(rescale_name)
if blocks is not None:
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
else:
return None
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
blocks = v[0]
rescale = v[1]
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
block_num, block_size, *_ = blocks.shape
try:
# Get r
I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
normed_q = q
if alpha > 0: # alpha in oft/boft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(original_weight)
lora_diff = torch.einsum(
"k n m, k n ... -> k m ...",
(r * strength) - strength * I,
original_weight,
)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -1,17 +0,0 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@@ -1,57 +0,0 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field, constr
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[constr(max_length=2048)] = Field(
None, description='Negative prompt\n'
)
prompt: constr(max_length=2048) = Field(..., description='Prompt')
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@@ -1,422 +0,0 @@
# generated by datamodel-codegen:
# filename: https://api.comfy.org/openapi
# timestamp: 2025-04-23T15:56:33+00:00
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import AnyUrl, BaseModel, Field, confloat, conint
class Customer(BaseModel):
createdAt: Optional[datetime] = Field(
None, description='The date and time the user was created'
)
email: Optional[str] = Field(None, description='The email address for this user')
id: str = Field(..., description='The firebase UID of the user')
name: Optional[str] = Field(None, description='The name for this user')
updatedAt: Optional[datetime] = Field(
None, description='The date and time the user was last updated'
)
class Error(BaseModel):
details: Optional[List[str]] = Field(
None,
description='Optional detailed information about the error or hints for resolving it.',
)
message: Optional[str] = Field(
None, description='A clear and concise description of the error.'
)
class ErrorResponse(BaseModel):
error: str
message: str
class ImageRequest(BaseModel):
aspect_ratio: Optional[str] = Field(
None,
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
)
color_palette: Optional[Dict[str, Any]] = Field(
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
)
magic_prompt_option: Optional[str] = Field(
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
)
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
negative_prompt: Optional[str] = Field(
None,
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
)
num_images: Optional[conint(ge=1, le=8)] = Field(
1, description='Optional. Number of images to generate (1-8). Defaults to 1.'
)
prompt: str = Field(
..., description='Required. The prompt to use to generate the image.'
)
resolution: Optional[str] = Field(
None,
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
)
seed: Optional[conint(ge=0, le=2147483647)] = Field(
None, description='Optional. A number between 0 and 2147483647.'
)
style_type: Optional[str] = Field(
None,
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
)
class Datum(BaseModel):
is_image_safe: Optional[bool] = Field(
None, description='Indicates whether the image is considered safe.'
)
prompt: Optional[str] = Field(
None, description='The prompt used to generate this image.'
)
resolution: Optional[str] = Field(
None, description="The resolution of the generated image (e.g., '1024x1024')."
)
seed: Optional[int] = Field(
None, description='The seed value used for this generation.'
)
style_type: Optional[str] = Field(
None,
description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
)
url: Optional[str] = Field(None, description='URL to the generated image.')
class Code(Enum):
int_1100 = 1100
int_1101 = 1101
int_1102 = 1102
int_1103 = 1103
class Code1(Enum):
int_1000 = 1000
int_1001 = 1001
int_1002 = 1002
int_1003 = 1003
int_1004 = 1004
class AspectRatio(str, Enum):
field_16_9 = '16:9'
field_9_16 = '9:16'
field_1_1 = '1:1'
class Config(BaseModel):
horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None
pan: Optional[confloat(ge=-10.0, le=10.0)] = None
roll: Optional[confloat(ge=-10.0, le=10.0)] = None
tilt: Optional[confloat(ge=-10.0, le=10.0)] = None
vertical: Optional[confloat(ge=-10.0, le=10.0)] = None
zoom: Optional[confloat(ge=-10.0, le=10.0)] = None
class Type(str, Enum):
simple = 'simple'
down_back = 'down_back'
forward_up = 'forward_up'
right_turn_forward = 'right_turn_forward'
left_turn_forward = 'left_turn_forward'
class CameraControl(BaseModel):
config: Optional[Config] = None
type: Optional[Type] = Field(None, description='Predefined camera movements type')
class Duration(str, Enum):
field_5 = 5
field_10 = 10
class Mode(str, Enum):
std = 'std'
pro = 'pro'
class TaskInfo(BaseModel):
external_task_id: Optional[str] = None
class Video(BaseModel):
duration: Optional[str] = Field(None, description='Total video duration')
id: Optional[str] = Field(None, description='Generated video ID')
url: Optional[AnyUrl] = Field(None, description='URL for generated video')
class TaskResult(BaseModel):
videos: Optional[List[Video]] = None
class TaskStatus(str, Enum):
submitted = 'submitted'
processing = 'processing'
succeed = 'succeed'
failed = 'failed'
class Data(BaseModel):
created_at: Optional[int] = Field(None, description='Task creation time')
task_id: Optional[str] = Field(None, description='Task ID')
task_info: Optional[TaskInfo] = None
task_result: Optional[TaskResult] = None
task_status: Optional[TaskStatus] = None
updated_at: Optional[int] = Field(None, description='Task update time')
class AspectRatio1(str, Enum):
field_16_9 = '16:9'
field_9_16 = '9:16'
field_1_1 = '1:1'
field_4_3 = '4:3'
field_3_4 = '3:4'
field_3_2 = '3:2'
field_2_3 = '2:3'
field_21_9 = '21:9'
class ImageReference(str, Enum):
subject = 'subject'
face = 'face'
class Image(BaseModel):
index: Optional[int] = Field(None, description='Image Number (0-9)')
url: Optional[AnyUrl] = Field(None, description='URL for generated image')
class TaskResult1(BaseModel):
images: Optional[List[Image]] = None
class Data1(BaseModel):
created_at: Optional[int] = Field(None, description='Task creation time')
task_id: Optional[str] = Field(None, description='Task ID')
task_result: Optional[TaskResult1] = None
task_status: Optional[TaskStatus] = None
task_status_msg: Optional[str] = Field(None, description='Task status information')
updated_at: Optional[int] = Field(None, description='Task update time')
class AspectRatio2(str, Enum):
field_16_9 = '16:9'
field_9_16 = '9:16'
field_1_1 = '1:1'
class CameraControl1(BaseModel):
config: Optional[Config] = None
type: Optional[Type] = Field(None, description='Predefined camera movements type')
class ModelName2(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
class TaskResult2(BaseModel):
videos: Optional[List[Video]] = None
class Data2(BaseModel):
created_at: Optional[int] = Field(None, description='Task creation time')
task_id: Optional[str] = Field(None, description='Task ID')
task_info: Optional[TaskInfo] = None
task_result: Optional[TaskResult2] = None
task_status: Optional[TaskStatus] = None
updated_at: Optional[int] = Field(None, description='Task update time')
class Code2(Enum):
int_1200 = 1200
int_1201 = 1201
int_1202 = 1202
int_1203 = 1203
class ResourcePackType(str, Enum):
decreasing_total = 'decreasing_total'
constant_period = 'constant_period'
class Status(str, Enum):
toBeOnline = 'toBeOnline'
online = 'online'
expired = 'expired'
runOut = 'runOut'
class ResourcePackSubscribeInfo(BaseModel):
effective_time: Optional[int] = Field(
None, description='Effective time, Unix timestamp in ms'
)
invalid_time: Optional[int] = Field(
None, description='Expiration time, Unix timestamp in ms'
)
purchase_time: Optional[int] = Field(
None, description='Purchase time, Unix timestamp in ms'
)
remaining_quantity: Optional[float] = Field(
None, description='Remaining quantity (updated with a 12-hour delay)'
)
resource_pack_id: Optional[str] = Field(None, description='Resource package ID')
resource_pack_name: Optional[str] = Field(None, description='Resource package name')
resource_pack_type: Optional[ResourcePackType] = Field(
None,
description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)',
)
status: Optional[Status] = Field(None, description='Resource Package Status')
total_quantity: Optional[float] = Field(None, description='Total quantity')
class Background(str, Enum):
transparent = 'transparent'
opaque = 'opaque'
class Moderation(str, Enum):
low = 'low'
auto = 'auto'
class OutputFormat(str, Enum):
png = 'png'
webp = 'webp'
jpeg = 'jpeg'
class Quality(str, Enum):
low = 'low'
medium = 'medium'
high = 'high'
class OpenAIImageEditRequest(BaseModel):
background: Optional[str] = Field(
None, description='Background transparency', examples=['opaque']
)
model: str = Field(
..., description='The model to use for image editing', examples=['gpt-image-1']
)
moderation: Optional[Moderation] = Field(
None, description='Content moderation setting', examples=['auto']
)
n: Optional[int] = Field(
None, description='The number of images to generate', examples=[1]
)
output_compression: Optional[int] = Field(
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
)
output_format: Optional[OutputFormat] = Field(
None, description='Format of the output image', examples=['png']
)
prompt: str = Field(
...,
description='A text description of the desired edit',
examples=['Give the rocketship rainbow coloring'],
)
quality: Optional[str] = Field(
None, description='The quality of the edited image', examples=['low']
)
size: Optional[str] = Field(
None, description='Size of the output image', examples=['1024x1024']
)
user: Optional[str] = Field(
None,
description='A unique identifier for end-user monitoring',
examples=['user-1234'],
)
class Quality1(str, Enum):
low = 'low'
medium = 'medium'
high = 'high'
standard = 'standard'
hd = 'hd'
class ResponseFormat(str, Enum):
url = 'url'
b64_json = 'b64_json'
class Style(str, Enum):
vivid = 'vivid'
natural = 'natural'
class OpenAIImageGenerationRequest(BaseModel):
background: Optional[Background] = Field(
None, description='Background transparency', examples=['opaque']
)
model: Optional[str] = Field(
None, description='The model to use for image generation', examples=['dall-e-3']
)
moderation: Optional[Moderation] = Field(
None, description='Content moderation setting', examples=['auto']
)
n: Optional[int] = Field(
None,
description='The number of images to generate (1-10). Only 1 supported for dall-e-3.',
examples=[1],
)
output_compression: Optional[int] = Field(
None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
)
output_format: Optional[OutputFormat] = Field(
None, description='Format of the output image', examples=['png']
)
prompt: str = Field(
...,
description='A text description of the desired image',
examples=['Draw a rocket in front of a blackhole in deep space'],
)
quality: Optional[Quality1] = Field(
None, description='The quality of the generated image', examples=['high']
)
response_format: Optional[ResponseFormat] = Field(
None, description='Response format of image data', examples=['b64_json']
)
size: Optional[str] = Field(
None,
description='Size of the image (e.g., 1024x1024, 1536x1024, auto)',
examples=['1024x1536'],
)
style: Optional[Style] = Field(
None, description='Style of the image (only for dall-e-3)', examples=['vivid']
)
user: Optional[str] = Field(
None,
description='A unique identifier for end-user monitoring',
examples=['user-1234'],
)
class Datum1(BaseModel):
b64_json: Optional[str] = Field(None, description='Base64 encoded image data')
revised_prompt: Optional[str] = Field(None, description='Revised prompt')
url: Optional[str] = Field(None, description='URL of the image')
class OpenAIImageGenerationResponse(BaseModel):
data: Optional[List[Datum1]] = None
class User(BaseModel):
email: Optional[str] = Field(None, description='The email address for this user.')
id: Optional[str] = Field(None, description='The unique id for this user.')
isAdmin: Optional[bool] = Field(
None, description='Indicates if the user has admin privileges.'
)
isApproved: Optional[bool] = Field(
None, description='Indicates if the user is approved.'
)
name: Optional[str] = Field(None, description='The name for this user.')

View File

@@ -1,337 +0,0 @@
import logging
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
api_key="your_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = operation.execute(client=api_client) # Returns immediately with the result
"""
from typing import (
Dict,
Type,
Optional,
Any,
TypeVar,
Generic,
)
from pydantic import BaseModel
from enum import Enum
import json
import requests
from urllib.parse import urljoin
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication and error handling.
"""
def __init__(
self,
base_url: str,
api_key: Optional[str] = None,
timeout: float = 30.0,
verify_ssl: bool = True,
):
self.base_url = base_url
self.api_key = api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
def get_headers(self) -> Dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
def request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""
Make an HTTP request to the API
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
json: JSON body data
files: Files to upload
headers: Additional headers
Returns:
Parsed JSON response
Raises:
requests.RequestException: If the request fails
"""
url = urljoin(self.base_url, path)
self.check_auth_token(self.api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
# Let requests handle the content type when files are present.
if files:
del request_headers["Content-Type"]
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Json: {json}")
try:
# If files are present, use data parameter instead of json
if files:
form_data = {}
if json:
form_data.update(json)
response = requests.request(
method=method,
url=url,
params=params,
data=form_data, # Use data instead of json
files=files,
headers=request_headers,
timeout=self.timeout,
verify=self.verify_ssl,
)
else:
response = requests.request(
method=method,
url=url,
params=params,
json=json,
headers=request_headers,
timeout=self.timeout,
verify=self.verify_ssl,
)
# Raise exception for error status codes
response.raise_for_status()
except requests.ConnectionError:
raise Exception(
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
)
except requests.Timeout:
raise Exception(
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
)
except requests.HTTPError as e:
status_code = e.response.status_code if hasattr(e, "response") else None
error_message = f"HTTP Error: {str(e)}"
# Try to extract detailed error message from JSON response
try:
if hasattr(e, "response") and e.response.content:
error_json = e.response.json()
if "error" in error_json and "message" in error_json["error"]:
error_message = f"API Error: {error_json['error']['message']}"
if "type" in error_json["error"]:
error_message += f" (Type: {error_json['error']['type']})"
else:
error_message = f"API Error: {error_json}"
except Exception as json_error:
# If we can't parse the JSON, fall back to the original error message
logging.debug(f"[DEBUG] Failed to parse error response: {str(json_error)}")
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
if status_code == 401:
error_message = "Unauthorized: Please login first to use this node."
if status_code == 402:
error_message = "Payment Required: Please add credits to your account to use this node."
if status_code == 409:
error_message = "There is a problem with your account. Please contact support@comfy.org. "
if status_code == 429:
error_message = "Rate Limit Exceeded: Please try again later."
raise Exception(error_message)
# Parse and return JSON response
if response.content:
return response.json()
return {}
def check_auth_token(self, auth_token):
"""Verify that an auth token is present."""
if auth_token is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""
Represents a single synchronous API operation.
"""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any]] = None,
api_base: str = "https://api.comfy.org",
auth_token: Optional[str] = None,
timeout: float = 604800.0,
verify_ssl: bool = True,
):
self.endpoint = endpoint
self.request = request
self.response = None
self.error = None
self.api_base = api_base
self.auth_token = auth_token
self.timeout = timeout
self.verify_ssl = verify_ssl
self.files = files
def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the API operation using the provided client or create one"""
try:
# Create client if not provided
if client is None:
if self.api_base is None:
raise ValueError("Either client or api_base must be provided")
client = ApiClient(
base_url=self.api_base,
api_key=self.auth_token,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
)
# Convert request model to dict, but use None for EmptyRequest
request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True)
# Debug log for request
logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}")
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
# Make the request
resp = client.request(
method=self.endpoint.method.value,
path=self.endpoint.path,
json=request_dict,
params=self.endpoint.query_params,
files=self.files,
)
# Debug log for response
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
logging.debug("=" * 50)
# Parse and return the response
return self._parse_response(resp)
except Exception as e:
logging.debug(f"[DEBUG] API Exception: {str(e)}")
raise Exception(str(e))
def _parse_response(self, resp):
"""Parse response data - can be overridden by subclasses"""
# The response is already the complete object, don't extract just the "data" field
# as that would lose the outer structure (created timestamp, etc.)
# Parse response using the provided model
self.response = self.endpoint.response_model.model_validate(resp)
logging.debug(f"[DEBUG] Parsed Response: {self.response}")
return self.response

View File

@@ -1,442 +0,0 @@
import io
from inspect import cleandoc
from comfy.utils import common_upscale
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse
)
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation
import numpy as np
from PIL import Image
import requests
import torch
import math
import base64
def downscale_input(image):
samples = image.movedim(-1,1)
#downscaling input images to roughly the same size as the outputs
total = int(1536 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1,-1)
return s
def validate_and_cast_response(response):
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise Exception("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors = []
# Process each image in the data array
for image_data in data:
image_url = image_data.url
b64_data = image_data.b64_json
if not image_url and not b64_data:
raise Exception("No image was generated in the response")
if b64_data:
img_data = base64.b64decode(b64_data)
img = Image.open(io.BytesIO(img_data))
elif image_url:
img_response = requests.get(image_url)
if img_response.status_code != 200:
raise Exception("Failed to download the image")
img = Image.open(io.BytesIO(img_response.content))
img = img.convert("RGBA")
# Convert to numpy array, normalize to float32 between 0 and 1
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)
# Add to list of tensors
image_tensors.append(img_tensor)
return torch.stack(image_tensors, dim=0)
class OpenAIDalle2(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
Uses the proxy at /proxy/openai/images/generations. Returned URLs are shortlived,
so download or cache results if you need to keep them.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (IO.STRING, {
"multiline": True,
"default": "",
"tooltip": "Text prompt for DALL·E",
}),
},
"optional": {
"seed": (IO.INT, {
"default": 0,
"min": 0,
"max": 2**31-1,
"step": 1,
"display": "number",
"tooltip": "not implemented yet in backend",
}),
"size": (IO.COMBO, {
"options": ["256x256", "512x512", "1024x1024"],
"default": "1024x1024",
"tooltip": "Image size",
}),
"n": (IO.INT, {
"default": 1,
"min": 1,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "How many images to generate",
}),
"image": (IO.IMAGE, {
"default": None,
"tooltip": "Optional reference image for image editing.",
}),
"mask": (IO.MASK, {
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG"
}
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None):
model = "dall-e-2"
path = "/proxy/openai/images/generations"
request_class = OpenAIImageGenerationRequest
img_binary = None
if image is not None and mask is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
input_tensor = image.squeeze().cpu()
height, width, channels = input_tensor.shape
rgba_tensor = torch.ones(height, width, 4, device="cpu")
rgba_tensor[:, :, :channels] = input_tensor
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
rgba_tensor[:,:,3] = (1-mask.squeeze().cpu())
rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze()
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_binary = img_byte_arr#.getvalue()
img_binary.name = "image.png"
elif image is not None or mask is not None:
raise Exception("Dall-E 2 image editing requires an image AND a mask")
# Build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=request_class,
response_model=OpenAIImageGenerationResponse
),
request=request_class(
model=model,
prompt=prompt,
n=n,
size=size,
seed=seed,
),
files={
"image": img_binary,
} if img_binary else None,
auth_token=auth_token
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
class OpenAIDalle3(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
Uses the proxy at /proxy/openai/images/generations. Returned URLs are shortlived,
so download or cache results if you need to keep them.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (IO.STRING, {
"multiline": True,
"default": "",
"tooltip": "Text prompt for DALL·E",
}),
},
"optional": {
"seed": (IO.INT, {
"default": 0,
"min": 0,
"max": 2**31-1,
"step": 1,
"display": "number",
"tooltip": "not implemented yet in backend",
}),
"quality" : (IO.COMBO, {
"options": ["standard","hd"],
"default": "standard",
"tooltip": "Image quality",
}),
"style": (IO.COMBO, {
"options": ["natural","vivid"],
"default": "natural",
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
}),
"size": (IO.COMBO, {
"options": ["1024x1024", "1024x1792", "1792x1024"],
"default": "1024x1024",
"tooltip": "Image size",
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG"
}
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None):
model = "dall-e-3"
# build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/images/generations",
method=HttpMethod.POST,
request_model=OpenAIImageGenerationRequest,
response_model=OpenAIImageGenerationResponse
),
request=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
size=size,
style=style,
seed=seed,
),
auth_token=auth_token
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
class OpenAIGPTImage1(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
Uses the proxy at /proxy/openai/images/generations. Returned URLs are shortlived,
so download or cache results if you need to keep them.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (IO.STRING, {
"multiline": True,
"default": "",
"tooltip": "Text prompt for GPT Image 1",
}),
},
"optional": {
"seed": (IO.INT, {
"default": 0,
"min": 0,
"max": 2**31-1,
"step": 1,
"display": "number",
"tooltip": "not implemented yet in backend",
}),
"quality": (IO.COMBO, {
"options": ["low","medium","high"],
"default": "low",
"tooltip": "Image quality, affects cost and generation time.",
}),
"background": (IO.COMBO, {
"options": ["opaque","transparent"],
"default": "opaque",
"tooltip": "Return image with or without background",
}),
"size": (IO.COMBO, {
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
"default": "auto",
"tooltip": "Image size",
}),
"n": (IO.INT, {
"default": 1,
"min": 1,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "How many images to generate",
}),
"image": (IO.IMAGE, {
"default": None,
"tooltip": "Optional reference image for image editing.",
}),
"mask": (IO.MASK, {
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG"
}
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None):
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = []
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i:i+1]
scaled_image = downscale_input(single_image).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
if batch_size == 1:
files.append(("image", img_binary))
else:
files.append(("image[]", img_binary))
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if image is None:
raise Exception("Cannot use a mask without an input image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
batch, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:,:,3] = (1-mask.squeeze().cpu())
scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format='PNG')
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
# Build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=request_class,
response_model=OpenAIImageGenerationResponse
),
request=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
auth_token=auth_token
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"OpenAIDalle2": OpenAIDalle2,
"OpenAIDalle3": OpenAIDalle3,
"OpenAIGPTImage1": OpenAIGPTImage1,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
}

View File

@@ -316,156 +316,3 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result

View File

@@ -1,45 +0,0 @@
import torch
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
negative_flat = negative.reshape(negative.shape[0], -1)
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def patch(self, model):
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
x = args['input']
cond_p = args['cond_denoised']
uncond_p = args['uncond_denoised']
out = args["denoised"]
alpha = optimized_scale(x - cond_p, x - uncond_p)
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
}

View File

@@ -1,100 +0,0 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
"""
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
Parameters:
x: Input tensor of shape (B, C, H, W)
scale_low: Scaling factor for low-frequency components (default: 1.0)
scale_high: Scaling factor for high-frequency components (default: 1.5)
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
Returns:
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
"""
# Preserve input dtype and device
dtype, device = x.dtype, x.device
# Convert to float32 for FFT computations
x = x.to(torch.float32)
# 1) Apply FFT and shift low frequencies to center
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
# Initialize mask with high-frequency scaling factor
mask = torch.ones(x_freq.shape, device=device) * scale_high
m = mask
for d in range(len(x_freq.shape) - 2):
dim = d + 2
cc = x_freq.shape[dim] // 2
f_c = min(freq_cutoff, cc)
m = m.narrow(dim, cc - f_c, f_c * 2)
# Apply low-frequency scaling factor to center region
m[:] = scale_low
# 3) Apply frequency-specific scaling
x_freq = x_freq * mask
# 4) Convert back to spatial domain
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
# 5) Restore original dtype
x_filtered = x_filtered.to(dtype)
return x_filtered
class FreSca:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for low-frequency components"}),
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for high-frequency components"}),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
guidance = cond - uncond
filtered_guidance = Fourier_filter(
guidance,
scale_low=scale_low,
scale_high=scale_high,
freq_cutoff=freq_cutoff,
)
filtered_cond = filtered_guidance + uncond
return [filtered_cond, uncond]
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"FreSca": FreSca,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca",
}

View File

@@ -1,55 +0,0 @@
import folder_paths
import comfy.sd
import comfy.model_management
class QuadrupleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class CLIPTextEncodeHiDream:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
tokens = clip.tokenize(clip_g)
tokens["l"] = clip.tokenize(clip_l)["l"]
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
tokens["llama"] = clip.tokenize(llama)["llama"]
return (clip.encode_from_tokens_scheduled(tokens), )
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}

View File

@@ -209,196 +209,6 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None):
vertices = torch.fliplr(vertices)
return vertices, faces
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
D, H, W = voxels.shape
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
corner_offsets = torch.tensor([
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)
has_outside = torch.any(~corner_signs, dim=1)
contains_surface = has_inside & has_outside
active_cells = cell_positions[contains_surface]
active_signs = corner_signs[contains_surface]
active_values = corner_values[contains_surface]
if active_cells.shape[0] == 0:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
edges = torch.tensor([
[0, 1], [0, 2], [0, 4], [1, 3],
[1, 5], [2, 3], [2, 6], [3, 7],
[4, 5], [4, 6], [5, 7], [6, 7]
], device=device)
cell_vertices = {}
progress = comfy.utils.ProgressBar(100)
for edge_idx, (e1, e2) in enumerate(edges):
progress.update(1)
crossing = active_signs[:, e1] != active_signs[:, e2]
if not crossing.any():
continue
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
v1 = active_values[cell_indices, e1]
v2 = active_values[cell_indices, e2]
t = torch.zeros_like(v1, device=device)
denom = v2 - v1
valid = denom != 0
t[valid] = (threshold - v1[valid]) / denom[valid]
t[~valid] = 0.5
p1 = corner_offsets[e1].float()
p2 = corner_offsets[e2].float()
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
for i, point in zip(cell_indices.tolist(), intersection):
if i not in cell_vertices:
cell_vertices[i] = []
cell_vertices[i].append(point)
# Calculate the final vertices as the average of intersection points for each cell
vertices = []
vertex_lookup = {}
vert_progress_mod = round(len(cell_vertices)/50)
for i, points in cell_vertices.items():
if not i % vert_progress_mod:
progress.update(1)
if points:
vertex = torch.stack(points).mean(dim=0)
vertex = vertex + active_cells[i].float()
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
vertices.append(vertex)
if not vertices:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
final_vertices = torch.stack(vertices)
inside_corners_mask = active_signs
outside_corners_mask = ~active_signs
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
for i in range(8):
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
inside_pos /= inside_counts
outside_pos /= outside_counts
gradients = inside_pos - outside_pos
pos_dirs = torch.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
], device=device)
cross_products = [
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
for i in range(3) for j in range(i+1, 3)
]
faces = []
all_keys = set(vertex_lookup.keys())
face_progress_mod = round(len(active_cells)/38*3)
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
dir_i = pos_dirs[i]
dir_j = pos_dirs[j]
cross_product = cross_products[pair_idx]
ni_positions = active_cells + dir_i
nj_positions = active_cells + dir_j
diag_positions = active_cells + dir_i + dir_j
alignments = torch.matmul(gradients, cross_product)
valid_quads = []
quad_indices = []
for idx, active_cell in enumerate(active_cells):
if not idx % face_progress_mod:
progress.update(1)
cell_key = tuple(active_cell.tolist())
ni_key = tuple(ni_positions[idx].tolist())
nj_key = tuple(nj_positions[idx].tolist())
diag_key = tuple(diag_positions[idx].tolist())
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
v0 = vertex_lookup[cell_key]
v1 = vertex_lookup[ni_key]
v2 = vertex_lookup[nj_key]
v3 = vertex_lookup[diag_key]
valid_quads.append((v0, v1, v2, v3))
quad_indices.append(idx)
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
cell_idx = quad_indices[q_idx]
if alignments[cell_idx] > 0:
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
else:
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
if faces:
faces = torch.stack(faces)
else:
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
v_min = 0
v_max = max(D, H, W)
final_vertices = final_vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
final_vertices = final_vertices / scale
final_vertices = torch.fliplr(final_vertices)
return final_vertices, faces
class MESH:
def __init__(self, vertices, faces):
@@ -427,34 +237,6 @@ class VoxelToMeshBasic:
return (MESH(torch.stack(vertices), torch.stack(faces)), )
class VoxelToMesh:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"algorithm": (["surface net", "basic"], ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, algorithm, threshold):
vertices = []
faces = []
if algorithm == "basic":
mesh_function = voxel_to_mesh
elif algorithm == "surface net":
mesh_function = voxel_to_mesh_surfnet
for x in voxel.data:
v, f = mesh_function(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
"""
@@ -462,7 +244,7 @@ def save_glb(vertices, faces, filepath, metadata=None):
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
@@ -629,6 +411,5 @@ NODE_CLASS_MAPPINGS = {
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"VoxelToMesh": VoxelToMesh,
"SaveGLB": SaveGLB,
}

View File

@@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -41,7 +41,7 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
return output_image, output_mask, model_file, normal_image, lineart_image
class Load3DAnimation():
@classmethod
@@ -59,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -77,16 +77,13 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image, image['camera_info']
return output_image, output_mask, model_file, normal_image
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@@ -98,22 +95,13 @@ class Preview3D():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
return {"ui": {"model_file": [model_file]}, "result": ()}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@@ -125,13 +113,7 @@ class Preview3DAnimation():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
return {"ui": {"model_file": [model_file]}, "result": ()}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,

View File

@@ -385,7 +385,7 @@ def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
@@ -446,6 +446,7 @@ class LTXVPreprocess:
CATEGORY = "image"
def preprocess(self, image, img_compression):
if img_compression > 0:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))

View File

@@ -2,11 +2,7 @@ import numpy as np
import scipy.ndimage
import torch
import comfy.utils
import node_helpers
import folder_paths
import random
import nodes
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
@@ -91,7 +87,6 @@ class ImageCompositeMasked:
CATEGORY = "image"
def composite(self, destination, source, x, y, resize_source, mask = None):
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,)
@@ -365,30 +360,6 @@ class ThresholdMask:
mask = (mask > value).float()
return (mask,)
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {
"required": {"mask": ("MASK",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
FUNCTION = "execute"
CATEGORY = "mask"
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
@@ -403,7 +374,6 @@ NODE_CLASS_MAPPINGS = {
"FeatherMask": FeatherMask,
"GrowMask": GrowMask,
"ThresholdMask": ThresholdMask,
"MaskPreview": MaskPreview
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@@ -1,56 +0,0 @@
# from https://github.com/bebebe666/OptimalSteps
import numpy as np
import torch
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = np.linspace(0, 1, len(t_steps))
ys = np.log(t_steps[::-1])
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
}
class OptimalStepsScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["FLUX", "Wan"], ),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
if (steps + 1) != len(sigmas):
sigmas = loglinear_interp(sigmas, steps + 1)
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler,
}

View File

@@ -6,7 +6,7 @@ import math
import comfy.utils
import comfy.model_management
import node_helpers
class Blend:
def __init__(self):
@@ -34,7 +34,6 @@ class Blend:
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)

View File

@@ -1,8 +1,6 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
import sys
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
@@ -25,7 +23,7 @@ class Int(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
"required": {"value": (IO.INT, {"control_after_generate": True})},
}
RETURN_TYPES = (IO.INT,)
@@ -40,7 +38,7 @@ class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
"required": {"value": (IO.FLOAT, {})},
}
RETURN_TYPES = (IO.FLOAT,)

View File

@@ -50,15 +50,13 @@ class SaveWEBM:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
stream.width = images.shape[-2]
stream.height = images.shape[-3]
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
stream.pix_fmt = "yuv420p"
stream.bit_rate = 0
stream.options = {'crf': str(crf)}
if codec == "av1":
stream.options["preset"] = "6"
for frame in images:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")

View File

@@ -3,8 +3,6 @@ import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision
class WanImageToVideo:
@@ -51,258 +49,6 @@ class WanImageToVideo:
return (positive, negative, out_latent)
class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image
if clip_vision_end_image is not None:
if clip_vision_output is not None:
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
clip_vision_output = comfy.clip_vision.Output()
clip_vision_output.penultimate_hidden_states = states
else:
clip_vision_output = clip_vision_end_image
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
class WanVaceToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
},
"optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ),
"reference_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
EXPERIMENTAL = True
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
latent_length = ((length - 1) // 4) + 1
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if control_video.shape[0] < length:
control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
else:
control_video = torch.ones((length, height, width, 3)) * 0.5
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
reference_image = vae.encode(reference_image[:, :, :, :3])
reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)
if control_masks is None:
mask = torch.ones((length, height, width, 1))
else:
mask = control_masks
if mask.ndim == 3:
mask = mask.unsqueeze(1)
mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
if mask.shape[0] < length:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)
control_video = control_video - 0.5
inactive = (control_video * (1 - mask)) + 0.5
reactive = (control_video * mask) + 0.5
inactive = vae.encode(inactive[:, :, :, :3])
reactive = vae.encode(reactive[:, :, :, :3])
control_video_latent = torch.cat((inactive, reactive), dim=1)
if reference_image is not None:
control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)
vae_stride = 8
height_mask = height // vae_stride
width_mask = width // vae_stride
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
mask = mask.permute(2, 4, 0, 1, 3)
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
trim_latent = 0
if reference_image is not None:
mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
mask = torch.cat((mask_pad, mask), dim=1)
latent_length += reference_image.shape[2]
trim_latent = reference_image.shape[2]
mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent)
class TrimVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def op(self, samples, trim_amount):
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.30"
__version__ = "0.3.27"

View File

@@ -15,7 +15,7 @@ import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
@@ -59,45 +59,27 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@@ -144,8 +126,6 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please
@@ -434,14 +414,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True
@@ -797,7 +776,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@@ -808,7 +787,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@@ -820,7 +799,7 @@ def validate_prompt(prompt):
"details": "",
"extra_info": {}
}
return (False, error, [], {})
return (False, error, [], [])
good_outputs = set()
errors = []

View File

@@ -85,7 +85,6 @@ cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
"fbx" : "model",
}
def map_legacy(folder_name: str) -> str:
@@ -141,14 +140,11 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
videos = filter_files_content_types(files, ["video"])
Note:
- 'model' in MIME context refers to 3D models, not files containing trained weights and parameters
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []

10
main.py
View File

@@ -10,7 +10,6 @@ from app.logger import setup_logger
import itertools
import utils.extra_config
import logging
import sys
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
@@ -157,13 +156,7 @@ def cuda_malloc_warning():
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -302,7 +295,6 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui()

View File

@@ -44,11 +44,3 @@ def string_to_torch_dtype(string):
return torch.float16
if string == "bf16":
return torch.bfloat16
def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
return destination, source

View File

@@ -786,8 +786,6 @@ class ControlNetLoader:
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
return (controlnet,)
class DiffControlNetLoader:
@@ -917,7 +915,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -927,10 +925,29 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
elif type == "cosmos":
clip_type = comfy.sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = comfy.sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
model_options = {}
if device == "cpu":
@@ -945,7 +962,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -955,13 +972,19 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
model_options = {}
if device == "cpu":
@@ -983,8 +1006,6 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
return (clip_vision,)
class CLIPVisionEncode:
@@ -1629,7 +1650,6 @@ class LoadImage:
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return {"required":
{"image": (sorted(files), {"image_upload": True})},
}
@@ -1668,9 +1688,6 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
@@ -2106,25 +2123,21 @@ def get_module_name(module_path: str) -> str:
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = get_module_name(module_path)
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
sys_module_name = module_name
elif os.path.isdir(module_path):
sys_module_name = module_path.replace(".", "_x_")
try:
logging.debug("Trying to load custom node {}".format(module_path))
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path)
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py"))
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module_dir = module_path
module = importlib.util.module_from_spec(module_spec)
sys.modules[sys_module_name] = module
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
@@ -2254,15 +2267,6 @@ def init_builtin_extra_nodes():
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_fresca.py",
]
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
api_nodes_files = [
"nodes_api.py",
]
import_failed = []
@@ -2270,10 +2274,6 @@ def init_builtin_extra_nodes():
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
import_failed.append(node_file)
for node_file in api_nodes_files:
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
import_failed.append(node_file)
return import_failed

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.30"
version = "0.3.27"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,4 @@
comfyui-frontend-package==1.17.11
comfyui-workflow-templates==0.1.3
comfyui-frontend-package==1.14.5
torch
torchsde
torchvision
@@ -22,5 +21,4 @@ psutil
kornia>=0.7.1
spandrel
soundfile
av>=14.1.0
pydantic~=2.0
av

View File

@@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
if request.path.endswith('.js') or request.path.endswith('.css'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@@ -580,9 +580,6 @@ class PromptServer():
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
if hasattr(obj_class, 'API_NODE'):
info['api_node'] = obj_class.API_NODE
return info
@routes.get("/object_info")
@@ -660,13 +657,7 @@ class PromptServer():
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
error = {
"type": "no_prompt",
"message": "No prompt provided",
"details": "No prompt provided",
"extra_info": {}
}
return web.json_response({"error": error, "node_errors": {}}, status=400)
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@routes.post("/queue")
async def post_queue(request):
@@ -739,12 +730,6 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])

View File

@@ -1,17 +1,14 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types, extension_mimetypes_cache
from unittest.mock import patch
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'],
'model': ['gltf', 'glb', 'obj', 'fbx', 'stl']
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@@ -25,18 +22,7 @@ def mock_dir(file_extensions):
yield directory
@pytest.fixture
def patched_mimetype_cache(file_extensions):
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
new_cache = extension_mimetypes_cache.copy()
for extension in file_extensions["model"]:
new_cache[extension] = "model"
with patch("folder_paths.extension_mimetypes_cache", new_cache):
yield
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
@@ -44,7 +30,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])