Compare commits

..

2 Commits

Author SHA1 Message Date
pythongosssss
fde9fdddff Allow running with non working 2025-03-28 11:46:05 +08:00
pythongosssss
7bf381bc9e Add model management and database
- use sqlalchemy + alembic + sqlite for db
- extract model data and previews
- endpoints for db interactions
- add tests
2025-03-28 11:39:56 +08:00
189 changed files with 272208 additions and 4431 deletions

View File

@@ -12,7 +12,7 @@ on:
description: 'CUDA version'
required: true
type: string
default: "126"
default: "124"
python_minor:
description: 'Python minor version'
required: true
@@ -22,7 +22,7 @@ on:
description: 'Python patch version'
required: true
type: string
default: "9"
default: "8"
jobs:

View File

@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip

58
.github/workflows/update-frontend.yml vendored Normal file
View File

@@ -0,0 +1,58 @@
name: Update Frontend Release
on:
workflow_dispatch:
inputs:
version:
description: "Frontend version to update to (e.g., 1.0.0)"
required: true
type: string
jobs:
update-frontend:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout ComfyUI
uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install wait-for-it
# Frontend asset will be downloaded to ComfyUI/web_custom_versions/Comfy-Org_ComfyUI_frontend/{version}
- name: Start ComfyUI server
run: |
python main.py --cpu --front-end-version Comfy-Org/ComfyUI_frontend@${{ github.event.inputs.version }} 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 30
- name: Configure Git
run: |
git config --global user.name "GitHub Action"
git config --global user.email "action@github.com"
# Replace existing frontend content with the new version and remove .js.map files
# See https://github.com/Comfy-Org/ComfyUI_frontend/issues/2145 for why we remove .js.map files
- name: Update frontend content
run: |
rm -rf web/
cp -r web_custom_versions/Comfy-Org_ComfyUI_frontend/${{ github.event.inputs.version }} web/
rm web/**/*.js.map
- name: Create Pull Request
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.PR_BOT_PAT }}
commit-message: "Update frontend to v${{ github.event.inputs.version }}"
title: "Frontend Update: v${{ github.event.inputs.version }}"
body: |
Automated PR to update frontend content to version ${{ github.event.inputs.version }}
This PR was created automatically by the frontend update workflow.
branch: release-${{ github.event.inputs.version }}
base: master
labels: Frontend,dependencies

View File

@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "126"
default: "124"
python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "8"
# push:
# branches:
# - master

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "126"
default: "124"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "8"
# push:
# branches:
# - master

View File

@@ -11,13 +11,13 @@
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

View File

@@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular visual AI engine and application.**
**The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
@@ -31,24 +31,10 @@
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
- Get the latest commits and completely portable.
- Available on Windows.
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
### [Installing ComfyUI](#installing)
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
@@ -61,14 +47,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
@@ -136,7 +119,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
# Installing
## Windows Portable
## Windows
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
@@ -146,8 +129,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@@ -156,18 +137,9 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
You can install and start ComfyUI using comfy-cli:
```bash
pip install comfy-cli
comfy install
```
## Manual Install (Windows, Linux)
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
Git clone this repo.
@@ -179,11 +151,11 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
### Intel GPUs (Windows and Linux)
@@ -213,7 +185,7 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements:
@@ -261,13 +233,6 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
#### Cambricon MLUs
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
# Running
@@ -324,8 +289,6 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
## Support and dev channel
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
See also: [https://www.comfy.org/](https://www.comfy.org/)
@@ -342,7 +305,7 @@ For any bugs, issues, or feature requests related to the frontend, please use th
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated fortnightly.
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
@@ -359,7 +322,7 @@ To use the most up-to-date frontend version:
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend

119
alembic.ini Normal file
View File

@@ -0,0 +1,119 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

3
alembic_db/README.md Normal file
View File

@@ -0,0 +1,3 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

75
alembic_db/env.py Normal file
View File

@@ -0,0 +1,75 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic_db/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,58 @@
"""init
Revision ID: 2fb22c4fff36
Revises:
Create Date: 2025-03-27 19:00:47.686079
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2fb22c4fff36'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('architecture', sa.Text(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
op.create_table('tag',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
op.create_table('model_tag',
sa.Column('model_type', sa.Text(), nullable=False),
sa.Column('model_path', sa.Text(), nullable=False),
sa.Column('tag_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model_tag')
op.drop_table('tag')
op.drop_table('model')
# ### end Alembic commands ###

View File

@@ -1,9 +1,9 @@
from aiohttp import web
from typing import Optional
from folder_paths import folder_names_and_paths, get_directory_by_type
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger
import os
class InternalRoutes:
'''
@@ -15,10 +15,26 @@ class InternalRoutes:
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
try:
file_list = self.file_service.list_files(directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@@ -51,20 +67,6 @@ class InternalRoutes:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
@self.routes.get('/files/{directory_type}')
async def get_files(request: web.Request) -> web.Response:
directory_type = request.match_info['directory_type']
if directory_type not in ("output", "input", "temp"):
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
def get_app(self):
if self._app is None:
self._app = web.Application()

View File

@@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
def list_files(self, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)

View File

@@ -4,93 +4,12 @@ import os
import folder_paths
import glob
from aiohttp import web
import json
import logging
from functools import lru_cache
from utils.json_util import merge_json_recursive
# Extra locale files to load into main.json
EXTRA_LOCALE_FILES = [
"nodeDefs.json",
"commands.json",
"settings.json",
]
def safe_load_json_file(file_path: str) -> dict:
if not os.path.exists(file_path):
return {}
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
logging.error(f"Error loading {file_path}")
return {}
class CustomNodeManager:
@lru_cache(maxsize=1)
def build_translations(self):
"""Load all custom nodes translations during initialization. Translations are
expected to be loaded from `locales/` folder.
The folder structure is expected to be the following:
- custom_nodes/
- custom_node_1/
- locales/
- en/
- main.json
- commands.json
- settings.json
returned translations are expected to be in the following format:
{
"en": {
"nodeDefs": {...},
"commands": {...},
"settings": {...},
...{other main.json keys}
}
}
"""
translations = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
# Sort glob results for deterministic ordering
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
locales_dir = os.path.join(custom_node_dir, "locales")
if not os.path.exists(locales_dir):
continue
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
lang_code = os.path.basename(os.path.dirname(lang_dir))
if lang_code not in translations:
translations[lang_code] = {}
# Load main.json
main_file = os.path.join(lang_dir, "main.json")
node_translations = safe_load_json_file(main_file)
# Load extra locale files
for extra_file in EXTRA_LOCALE_FILES:
extra_file_path = os.path.join(lang_dir, extra_file)
key = extra_file.split(".")[0]
json_data = safe_load_json_file(extra_file_path)
if json_data:
node_translations[key] = json_data
if node_translations:
translations[lang_code] = merge_json_recursive(
translations[lang_code], node_translations
)
return translations
"""
Placeholder to refactor the custom node management features from ComfyUI-Manager.
Currently it only contains the custom workflow templates feature.
"""
def add_routes(self, routes, webapp, loadedModules):
@routes.get("/workflow_templates")
@@ -99,36 +18,17 @@ class CustomNodeManager:
files = [
file
for folder in folder_paths.get_folder_paths("custom_nodes")
for file in glob.glob(
os.path.join(folder, "*/example_workflows/*.json")
)
for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
]
workflow_templates_dict = (
{}
) # custom_nodes folder name -> example workflow names
workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
for file in files:
custom_nodes_name = os.path.basename(
os.path.dirname(os.path.dirname(file))
)
custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
workflow_name = os.path.splitext(os.path.basename(file))[0]
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
workflow_name
)
workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
workflows_dir = os.path.join(module_dir, "example_workflows")
workflows_dir = os.path.join(module_dir, 'example_workflows')
if os.path.exists(workflows_dir):
webapp.add_routes(
[
web.static(
"/api/workflow_templates/" + module_name, workflows_dir
)
]
)
@routes.get("/i18n")
async def get_i18n(request):
"""Returns translations from all custom nodes' locales folders."""
return web.json_response(self.build_translations())
webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])

118
app/database/db.py Normal file
View File

@@ -0,0 +1,118 @@
import logging
import os
import shutil
import sys
from app.database.models import Tag
from comfy.cli_args import args
try:
import alembic
import sqlalchemy
except ImportError as e:
req_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
)
logging.error(
f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
)
exit(-1)
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
Session = None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
config.set_main_option("sqlalchemy.url", args.database_url)
return config
def get_db_path():
url = args.database_url
if url.startswith("sqlite:///"):
return url.split("///")[1]
else:
raise ValueError(f"Unsupported database URL '{url}'.")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
config = get_alembic_config()
# Check if we need to upgrade
engine = create_engine(db_url)
conn = engine.connect()
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(config)
target_rev = script.get_current_head()
if current_rev != target_rev:
# Backup the database pre upgrade
db_path = get_db_path()
backup_path = db_path + ".bkp"
if os.path.exists(db_path):
shutil.copy(db_path, backup_path)
else:
backup_path = None
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.error(f"Error upgrading database: {e}")
raise e
global Session
Session = sessionmaker(bind=engine)
if not current_rev:
# Init db, populate models
from app.model_processor import model_processor
session = create_session()
model_processor.populate_models(session)
# populate tags
tags = (
"character",
"style",
"concept",
"clothing",
"pose",
"background",
"vehicle",
"object",
"animal",
"action",
)
for tag in tags:
session.add(Tag(name=tag))
session.commit()
def can_create_session():
return Session is not None
def create_session():
return Session()

76
app/database/models.py Normal file
View File

@@ -0,0 +1,76 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
Table,
ForeignKeyConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
def to_dict(obj):
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
ModelTag = Table(
"model_tag",
Base.metadata,
Column(
"model_type",
Text,
primary_key=True,
),
Column(
"model_path",
Text,
primary_key=True,
),
Column("tag_id", Integer, primary_key=True),
ForeignKeyConstraint(
["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
),
ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
)
class Model(Base):
__tablename__ = "model"
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
title = Column(Text)
description = Column(Text)
architecture = Column(Text)
hash = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
# Relationship with tags
tags = relationship("Tag", secondary=ModelTag, back_populates="models")
def to_dict(self):
dict = to_dict(self)
dict["tags"] = [tag.to_dict() for tag in self.tags]
return dict
class Tag(Base):
__tablename__ = "tag"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Text, nullable=False, unique=True)
# Relationship with models
models = relationship("Model", secondary=ModelTag, back_populates="tags")
def to_dict(self):
return to_dict(self)

View File

@@ -3,10 +3,8 @@ import argparse
import logging
import os
import re
import sys
import tempfile
import zipfile
import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
@@ -14,19 +12,9 @@ from typing import TypedDict, Optional
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
try:
import comfyui_frontend_package
except ImportError:
# TODO: Remove the check after roll out of 0.3.16
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. Please install the updated requirements.txt file by running:\n{sys.executable} -m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n")
exit(-1)
REQUEST_TIMEOUT = 10 # seconds
@@ -121,7 +109,7 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod

View File

@@ -1,19 +1,30 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
from app.database.db import create_session
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
from folder_paths import map_legacy, filter_files_extensions, get_full_path
from app.database.models import Tag, Model
from app.model_processor import get_model_previews, model_processor
from utils.web import dumps
from sqlalchemy.orm import joinedload
import sqlalchemy.exc
def bad_request(message: str):
return web.json_response({"error": message}, status=400)
def missing_field(field: str):
return bad_request(f"{field} is required")
def not_found(message: str):
return web.json_response({"error": message + " not found"}, status=404)
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -62,7 +73,7 @@ class ModelFileManager:
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
previews = get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
@@ -76,6 +87,183 @@ class ModelFileManager:
except:
return web.Response(status=404)
@routes.get("/v2/models")
async def get_models(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
query = session.query(Model).options(joinedload(Model.tags))
if model_path:
query = query.filter(Model.path == model_path)
if model_type:
query = query.filter(Model.type == model_type)
models = query.all()
if model_path and model_type:
if len(models) == 0:
return not_found("Model")
return web.json_response(models[0].to_dict(), dumps=dumps)
return web.json_response([model.to_dict() for model in models], dumps=dumps)
@routes.post("/v2/models")
async def add_model(request):
with create_session() as session:
data = await request.json()
model_type = data.get("type", None)
model_path = data.get("path", None)
if not model_type:
return missing_field("type")
if not model_path:
return missing_field("path")
tags = data.pop("tags", [])
fields = Model.metadata.tables["model"].columns.keys()
# Validate keys are valid model fields
for key in data.keys():
if key not in fields:
return bad_request(f"Invalid field: {key}")
# Validate file exists
if not get_full_path(model_type, model_path):
return not_found(f"File '{model_type}/{model_path}'")
model = Model()
for field in fields:
if field in data:
setattr(model, field, data[field])
model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
for tag in tags:
if tag not in [t.id for t in model.tags]:
return not_found(f"Tag '{tag}'")
try:
session.add(model)
session.commit()
except sqlalchemy.exc.IntegrityError as e:
session.rollback()
return bad_request(e.orig.args[0])
model_processor.run()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models")
async def delete_model(request):
with create_session() as session:
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if not model_path:
return missing_field("path")
if not model_type:
return missing_field("type")
full_path = get_full_path(model_type, model_path)
if full_path:
return bad_request("Model file exists, please delete the file before deleting the model record.")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
session.delete(model)
session.commit()
return web.Response(status=204)
@routes.get("/v2/tags")
async def get_tags(request):
with create_session() as session:
tags = session.query(Tag).all()
return web.json_response(
[{"id": tag.id, "name": tag.name} for tag in tags]
)
@routes.post("/v2/tags")
async def create_tag(request):
with create_session() as session:
data = await request.json()
name = data.get("name", None)
if not name:
return missing_field("name")
tag = Tag(name=name)
session.add(tag)
session.commit()
return web.json_response({"id": tag.id, "name": tag.name})
@routes.delete("/v2/tags")
async def delete_tag(request):
with create_session() as session:
tag_id = request.query.get("id", None)
if not tag_id:
return missing_field("id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
return not_found("Tag")
session.delete(tag)
session.commit()
return web.Response(status=204)
@routes.post("/v2/models/tags")
async def add_model_tag(request):
with create_session() as session:
data = await request.json()
tag_id = data.get("tag", None)
model_path = data.get("path", None)
model_type = data.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
tag = session.query(Tag).filter(Tag.id == tag_id).first()
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags.append(tag)
session.commit()
return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models/tags")
async def delete_model_tag(request):
with create_session() as session:
tag_id = request.query.get("tag", None)
model_path = request.query.get("path", None)
model_type = request.query.get("type", None)
if tag_id is None:
return missing_field("tag")
if model_path is None:
return missing_field("path")
if model_type is None:
return missing_field("type")
try:
tag_id = int(tag_id)
except ValueError:
return bad_request("Invalid tag id")
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
if not model:
return not_found("Model")
model.tags = [tag for tag in model.tags if tag.id != tag_id]
session.commit()
return web.Response(status=204)
@routes.get("/v2/models/missing")
async def get_missing_models(request):
return web.json_response(model_processor.missing_models)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
@@ -146,39 +334,5 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

263
app/model_processor.py Normal file
View File

@@ -0,0 +1,263 @@
import base64
from datetime import datetime
import glob
import hashlib
from io import BytesIO
import json
import logging
import os
import threading
import time
import comfy.utils
from app.database.models import Model
from app.database.db import create_session
from comfy.cli_args import args
from folder_paths import (
filter_files_content_types,
get_full_path,
folder_names_and_paths,
get_filename_list,
)
from PIL import Image
from urllib import request
def get_model_previews(
filepath: str, check_metadata: bool = True
) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if not check_metadata:
return result
safetensors_file = next(
filter(lambda x: x.endswith(".safetensors"), match_files), None
)
safetensors_metadata = {}
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(
safetensors_filepath, max_size=8 * 1024 * 1024
)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
"ssmd_cover_images", None
)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
class ModelProcessor:
def __init__(self):
self._thread = None
self._lock = threading.Lock()
self._run = False
self.missing_models = []
def run(self):
if args.disable_model_processing:
return
if self._thread is None:
# Lock to prevent multiple threads from starting
with self._lock:
self._run = True
if self._thread is None:
self._thread = threading.Thread(target=self._process_models)
self._thread.daemon = True
self._thread.start()
def populate_models(self, session):
# Ensure database state matches filesystem
existing_models = session.query(Model).all()
for folder_name in folder_names_and_paths.keys():
if folder_name == "custom_nodes" or folder_name == "configs":
continue
seen = set()
files = get_filename_list(folder_name)
for file in files:
if file in seen:
logging.warning(f"Skipping duplicate named model: {file}")
continue
seen.add(file)
existing_model = None
for model in existing_models:
if model.path == file and model.type == folder_name:
existing_model = model
break
if existing_model:
# Model already exists in db, remove from list and skip
existing_models.remove(existing_model)
continue
file_path = get_full_path(folder_name, file)
model = Model(
path=file,
type=folder_name,
date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
)
session.add(model)
for model in existing_models:
if not get_full_path(model.type, model.path):
logging.warning(f"Model {model.path} not found")
self.missing_models.append({"type": model.type, "path": model.path})
session.commit()
def _get_models(self, session):
models = session.query(Model).filter(Model.hash == None).all()
return models
def _process_file(self, model_path):
is_safetensors = model_path.endswith(".safetensors")
metadata = {}
h = hashlib.sha256()
with open(model_path, "rb", buffering=0) as f:
if is_safetensors:
# Read header length (8 bytes)
header_size_bytes = f.read(8)
header_len = int.from_bytes(header_size_bytes, "little")
h.update(header_size_bytes)
# Read header
header_bytes = f.read(header_len)
h.update(header_bytes)
try:
metadata = json.loads(header_bytes)
except json.JSONDecodeError:
pass
# Read rest of file
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest(), metadata
def _populate_info(self, model, metadata):
model.title = metadata.get("modelspec.title", None)
model.description = metadata.get("modelspec.description", None)
model.architecture = metadata.get("modelspec.architecture", None)
def _extract_image(self, model_path, metadata):
# check if image already exists
if len(get_model_previews(model_path, check_metadata=False)) > 0:
return
image_path = os.path.splitext(model_path)[0] + ".webp"
if os.path.exists(image_path):
return
cover_images = metadata.get("ssmd_cover_images", None)
image = None
if cover_images:
try:
cover_images = json.loads(cover_images)
if len(cover_images) > 0:
image_data = cover_images[0]
image = Image.open(BytesIO(base64.b64decode(image_data)))
except Exception as e:
logging.warning(
f"Error extracting cover image for model {model_path}: {e}"
)
if not image:
thumbnail = metadata.get("modelspec.thumbnail", None)
if thumbnail:
try:
response = request.urlopen(thumbnail)
image = Image.open(response)
except Exception as e:
logging.warning(
f"Error extracting thumbnail for model {model_path}: {e}"
)
if image:
image.thumbnail((512, 512))
image.save(image_path)
image.close()
def _process_models(self):
with create_session() as session:
checked = set()
self.populate_models(session)
while self._run:
self._run = False
models = self._get_models(session)
if len(models) == 0:
break
for model in models:
# prevent looping on the same model if it crashes
if model.path in checked:
continue
checked.add(model.path)
try:
time.sleep(0)
now = time.time()
model_path = get_full_path(model.type, model.path)
if not model_path:
logging.warning(f"Model {model.path} not found")
self.missing_models.append(model.path)
continue
logging.debug(f"Processing model {model_path}")
hash, header = self._process_file(model_path)
logging.debug(
f"Processed model {model_path} in {time.time() - now} seconds"
)
model.hash = hash
if header:
metadata = header.get("__metadata__", None)
if metadata:
self._populate_info(model, metadata)
self._extract_image(model_path, metadata)
session.commit()
except Exception as e:
logging.error(f"Error processing model {model.path}: {e}")
with self._lock:
self._thread = None
model_processor = ModelProcessor()

View File

@@ -43,11 +43,10 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -130,12 +129,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -182,9 +176,13 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
args = parser.parse_args()
@@ -196,17 +194,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
if args.force_fp16:
args.fp16_unet = True
# '--fast' is not provided, use an empty set
if args.fast is None:
args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

View File

@@ -102,10 +102,9 @@ class CLIPTextModel_(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:

View File

@@ -66,26 +66,13 @@ class IO(StrEnum):
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict):
route: str
"""The route to the remote source."""
refresh_button: bool
"""Specifies whether to show a refresh button in the UI below the widget."""
control_after_refresh: Literal["first", "last"]
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
timeout: int
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
max_retries: int
"""The maximum number of retries before aborting the request."""
refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
"""
default: bool | str | float | int | list | tuple
@@ -126,14 +113,6 @@ class InputTypeOptions(TypedDict):
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
# class InputTypeCombo(InputTypeOptions):
image_upload: bool
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
image_folder: Literal["input", "output", "temp"]
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
"""
remote: RemoteInputOptions
"""Specifies the configuration for a remote input."""
class HiddenInputTypeDict(TypedDict):
@@ -154,7 +133,7 @@ class HiddenInputTypeDict(TypedDict):
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
@@ -164,14 +143,14 @@ class InputTypeDict(TypedDict):
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
"""
DESCRIPTION: str
@@ -188,7 +167,7 @@ class ComfyNodeABC(ABC):
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
@@ -202,9 +181,9 @@ class ComfyNodeABC(ABC):
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
"""
return {"required": {}}
@@ -219,7 +198,7 @@ class ComfyNodeABC(ABC):
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@@ -230,7 +209,7 @@ class ComfyNodeABC(ABC):
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
@@ -248,7 +227,7 @@ class ComfyNodeABC(ABC):
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs which should be so treated.
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
RETURN_TYPES: tuple[IO]
@@ -258,19 +237,19 @@ class ComfyNodeABC(ABC):
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
"""
@@ -288,7 +267,7 @@ class CheckLazyMixin:
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]

View File

@@ -3,6 +3,9 @@ import math
import comfy.utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@@ -43,7 +46,7 @@ class CONDCrossAttn(CONDRegular):
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = math.lcm(s1[1], s2[1])
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
@@ -54,7 +57,7 @@ class CONDCrossAttn(CONDRegular):
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []

View File

@@ -418,7 +418,10 @@ def controlnet_config(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
@@ -686,7 +689,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()

View File

@@ -4,6 +4,105 @@ import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
@@ -114,7 +213,6 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors):
x = 0
@@ -131,7 +229,6 @@ def cat_tensors(tensors):
return out
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
@@ -187,3 +284,5 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict

View File

@@ -661,7 +661,7 @@ class UniPC:
if x_t is None:
if use_predictor:
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@@ -669,7 +669,7 @@ class UniPC:
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)

View File

@@ -40,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
return append_zero(sigmas)
@@ -1267,7 +1267,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
return x
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1289,80 +1289,50 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
if sigma_down == 0 or old_denoised is None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
if sigmas[i + 1] == 0 or old_denoised is None:
# Euler method
if cfg_pp:
d = to_d(x, sigmas[i], uncond_denoised)
x = denoised + d * sigma_down
d = to_d(x, sigma_hat, uncond_denoised)
x = denoised + d * sigmas[i + 1]
else:
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
if cfg_pp:
x = x + (denoised - uncond_denoised)
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
else:
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
if cfg_pp:
old_denoised = uncond_denoised
else:
old_denoised = denoised
old_denoised = denoised
return x
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
@torch.no_grad()
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
# Euler method
x = x + d * dt
else:
# Gradient estimation
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)

View File

@@ -407,52 +407,3 @@ class Cosmos1CV8x8x8(LatentFormat):
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean

View File

@@ -168,19 +168,15 @@ class Attention(nn.Module):
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
def cal_attn(self, q, k, v, mask=None):
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
def forward(
self,
x,
@@ -195,10 +191,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
return self.cal_attn(q, k, v, mask)
class FeedForward(nn.Module):
@@ -795,7 +788,10 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x = x + extra_per_block_pos_emb
for block in self.blocks:
x = block(
x,

View File

@@ -30,8 +30,6 @@ import torch.nn as nn
import torch.nn.functional as F
import logging
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from .patching import (
Patcher,
Patcher3D,
@@ -402,8 +400,6 @@ class CausalAttnBlock(nn.Module):
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.optimized_attention = vae_attention()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
@@ -417,7 +413,18 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v)
b, c, h, w = q.shape
h_ = self.optimized_attention(q, k, v)
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)
@@ -864,16 +871,18 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x)
# downsampling
h = self.conv_in(x)
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

View File

@@ -281,76 +281,54 @@ class UnPatcher3D(UnPatcher):
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)

View File

@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
self.build_pos_embed(device=device, dtype=dtype)
self.build_pos_embed(device=device)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
def build_pos_embed(self, device=None, dtype=None):
def build_pos_embed(self, device=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
@@ -242,7 +242,6 @@ class GeneralDIT(nn.Module):
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
@@ -293,7 +292,7 @@ class GeneralDIT(nn.Module):
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
else:
extra_pos_emb = None
@@ -477,8 +476,6 @@ class GeneralDIT(nn.Module):
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -489,8 +486,6 @@ class GeneralDIT(nn.Module):
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
@@ -498,6 +493,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
class VideoPositionEmb(nn.Module):
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
"""
It delegates the embedding generation to generate_embeddings function.
"""
B_T_H_W_C = x_B_T_H_W_C.shape
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
return embeddings
@@ -104,7 +104,6 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
device=None,
dtype=None,
):
"""
Generate embeddings for the given input size.
@@ -174,7 +173,6 @@ class LearnablePosEmbAxis(VideoPositionEmb):
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
@@ -186,16 +184,17 @@ class LearnablePosEmbAxis(VideoPositionEmb):
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
emb_h_H = self.pos_emb_h[:H].to(device=device)
emb_w_W = self.pos_emb_w[:W].to(device=device)
emb_t_T = self.pos_emb_t[:T].to(device=device)
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)

View File

@@ -18,7 +18,6 @@ import logging
import torch
from torch import nn
from enum import Enum
import math
from .cosmos_tokenizer.layers3d import (
EncoderFactorized,
@@ -90,8 +89,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
logging.info(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)
@@ -106,23 +105,17 @@ class CausalContinuousVideoTokenizer(nn.Module):
z, posteriors = self.distribution(moments)
latent_ch = z.shape[1]
latent_t = z.shape[2]
in_dtype = z.dtype
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
dtype = z.dtype
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
return ((z - mean) / std) * self.sigma_data
def decode(self, z):
in_dtype = z.dtype
latent_ch = z.shape[1]
latent_t = z.shape[2]
mean = self.latent_mean.view(latent_ch, -1)
std = self.latent_std.view(latent_ch, -1)
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
z = z / self.sigma_data
z = z * std + mean

View File

@@ -230,7 +230,8 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)

View File

@@ -5,15 +5,8 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -22,7 +15,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
device = torch.device("cpu")
else:
device = pos.device

View File

@@ -109,8 +109,9 @@ class Flux(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
@@ -185,7 +186,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

View File

@@ -240,8 +240,9 @@ class HunyuanVideo(nn.Module):
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
@@ -310,10 +311,10 @@ class HunyuanVideo(nn.Module):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])

View File

@@ -7,7 +7,7 @@ from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
@@ -377,16 +377,12 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@@ -420,23 +416,42 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
scale_factors=self.vae_scale_factors,
causal_fix=self.causal_temporal_positioning,
)
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
@@ -444,7 +459,7 @@ class LTXVModel(torch.nn.Module):
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
@@ -504,4 +519,8 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@@ -6,29 +6,16 @@ from einops import rearrange
from torch import Tensor
def latent_to_pixel_coords(
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
"""
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
configuration.
Args:
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
containing the latent corner coordinates of each token.
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
causal_fix (bool): Whether to take into account the different temporal scale
of the first frame. Default = False for backwards compatibility.
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
class Patchifier(ABC):
@@ -57,26 +44,29 @@ class Patchifier(ABC):
def patch_size(self):
return self._patch_size
def get_latent_coords(
self, latent_num_frames, latent_height, latent_width, batch_size, device
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
"""
Return a tensor of shape [batch_size, 3, num_patches] containing the
top-left corner latent coordinates of each latent patch.
The tensor is repeated for each batch element.
"""
latent_sample_coords = torch.meshgrid(
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
torch.arange(0, latent_height, self._patch_size[1], device=device),
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
)
return latent_coords
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
@@ -84,8 +74,6 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
b, _, f, h, w = latents.shape
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -93,7 +81,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents, latent_coords
return latents
def unpatchify(
self,

View File

@@ -15,7 +15,6 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
spatial_padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
@@ -39,7 +38,7 @@ class CausalConv3d(nn.Module):
stride=stride,
dilation=dilation,
padding=padding,
padding_mode=spatial_padding_mode,
padding_mode="zeros",
groups=groups,
)

View File

@@ -1,15 +1,13 @@
from __future__ import annotations
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
@@ -34,7 +32,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
@@ -42,13 +40,12 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -68,7 +65,6 @@ class Encoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.down_blocks = nn.ModuleList([])
@@ -86,7 +82,6 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -97,7 +92,6 @@ class Encoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = make_conv_nd(
@@ -107,7 +101,6 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 1, 1),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = make_conv_nd(
@@ -117,7 +110,6 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(1, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = make_conv_nd(
@@ -127,7 +119,6 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -138,34 +129,6 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown block: {block_name}")
@@ -189,18 +152,10 @@ class Encoder(nn.Module):
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var == "constant":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims,
output_channel,
conv_out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
@@ -242,15 +197,6 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
elif self.latent_log_var == "constant":
sample = sample[:, :-1, ...]
approx_ln_0 = (
-30
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
sample = torch.cat(
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
dim=1,
)
return sample
@@ -285,7 +231,7 @@ class Decoder(nn.Module):
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
@@ -293,7 +239,6 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -319,7 +264,6 @@ class Decoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
@@ -339,7 +283,6 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
@@ -351,7 +294,6 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@@ -364,21 +306,14 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
@@ -388,7 +323,6 @@ class Decoder(nn.Module):
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
@@ -406,13 +340,7 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
@@ -505,12 +433,6 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
inject_noise (`bool`, *optional*, defaults to `False`):
Whether to inject noise into the hidden states.
timestep_conditioning (`bool`, *optional*, defaults to `False`):
Whether to condition the hidden states on the timestep.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -529,7 +451,6 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
resnet_groups = (
@@ -555,17 +476,13 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
@@ -590,62 +507,9 @@ class UNetMidBlock3D(nn.Module):
return hidden_states
class SpaceToDepthDownsample(nn.Module):
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
super().__init__()
self.stride = stride
self.group_size = in_channels * math.prod(stride) // out_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=out_channels // math.prod(stride),
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
# skip connection
x_in = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
x_in = x_in.mean(dim=2)
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x = x + x_in
return x
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims,
in_channels,
stride,
residual=False,
out_channels_reduction_factor=1,
spatial_padding_mode="zeros",
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
):
super().__init__()
self.stride = stride
@@ -659,7 +523,6 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -695,7 +558,7 @@ class DepthToSpaceUpsample(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
@@ -728,7 +591,6 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
@@ -755,7 +617,6 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -780,7 +641,6 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -941,44 +801,9 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self, version=0, config=None):
def __init__(self, version=0):
super().__init__()
if config is None:
config = self.guess_config(version)
self.timestep_conditioning = config.get("timestep_conditioning", False)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()
def guess_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1005,7 +830,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False,
"causal_decoder": False,
}
elif version == 1:
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
@@ -1041,47 +866,37 @@ class VideoVAE(nn.Module):
"causal_decoder": False,
"timestep_conditioning": True,
}
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
],
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True
}
return config
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

View File

@@ -17,11 +17,7 @@ def make_conv_nd(
groups=1,
bias=True,
causal=False,
spatial_padding_mode="zeros",
temporal_padding_mode="zeros",
):
if not (spatial_padding_mode == temporal_padding_mode or causal):
raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
@@ -32,7 +28,6 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == 3:
if causal:
@@ -45,7 +40,6 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
spatial_padding_mode=spatial_padding_mode,
)
return ops.Conv3d(
in_channels=in_channels,
@@ -56,7 +50,6 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == (2, 1):
return DualConv3d(
@@ -66,7 +59,6 @@ def make_conv_nd(
stride=stride,
padding=padding,
bias=bias,
padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -18,13 +18,11 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -110,7 +108,6 @@ class DualConv3d(nn.Module):
self.padding1,
self.dilation1,
self.groups,
padding_mode=self.padding_mode,
)
if skip_time_conv:
@@ -125,7 +122,6 @@ class DualConv3d(nn.Module):
self.padding2,
self.dilation2,
self.groups,
padding_mode=self.padding_mode,
)
return x
@@ -141,16 +137,7 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(
x,
weight1,
self.bias1,
stride1,
padding1,
dilation1,
self.groups,
padding_mode=self.padding_mode,
)
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
@@ -167,16 +154,7 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(
x,
weight2,
self.bias2,
stride2,
padding2,
dilation2,
self.groups,
padding_mode=self.padding_mode,
)
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x

View File

@@ -1,622 +0,0 @@
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
from __future__ import annotations
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
#############################################################################
# Core NextDiT Model #
#############################################################################
class JointAttention(nn.Module):
"""Multi-head attention module."""
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
operation_settings={},
):
"""
Initialize the Attention module.
Args:
dim (int): Number of input dimensions.
n_heads (int): Number of heads.
n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
"""
super().__init__()
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
self.n_local_heads = n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
self.qkv = operation_settings.get("operations").Linear(
dim,
(n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
if qk_norm:
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
else:
self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
x_mask:
freqs_cis:
Returns:
"""
bsz, seqlen, _ = x.shape
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
return self.out(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
operation_settings={},
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple
of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden
dimension. Defaults to None.
"""
super().__init__()
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = operation_settings.get("operations").Linear(
dim,
hidden_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.w2 = operation_settings.get("operations").Linear(
hidden_dim,
dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.w3 = operation_settings.get("operations").Linear(
dim,
hidden_dim,
bias=False,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
class JointTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
operation_settings={},
) -> None:
"""
Initialize a TransformerBlock.
Args:
layer_id (int): Identifier for the layer.
dim (int): Embedding dimension of the input features.
n_heads (int): Number of attention heads.
n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as
query.
multiple_of (int):
ffn_dim_multiplier (float):
norm_eps (float):
"""
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
)
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
):
"""
Perform a forward pass through the TransformerBlock.
Args:
x (torch.Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
Returns:
torch.Tensor: Output tensor after applying attention and
feedforward layers.
"""
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
)
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
)
)
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
return x
class FinalLayer(nn.Module):
"""
The final layer of NextDiT.
"""
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
elementwise_affine=False,
eps=1e-6,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.linear = operation_settings.get("operations").Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
hidden_size,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
x = self.linear(x)
return x
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 4,
dim: int = 4096,
n_layers: int = 32,
n_refiner_layers: int = 2,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
out_features=dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
self.noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
operation_settings.get("operations").Linear(
cap_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
def unpatchify(
self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
) -> List[torch.Tensor]:
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
pH = pW = self.patch_size
imgs = []
for i in range(x.size(0)):
H, W = img_size[i]
begin = cap_size[i]
end = begin + (H // pH) * (W // pW)
imgs.append(
x[i][begin:end]
.view(H // pH, W // pW, pH, pW, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
)
if return_tensor:
imgs = torch.stack(imgs, dim=0)
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype
if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
"""
Forward pass of NextDiT.
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of text tokens/features
"""
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
return -x

View File

@@ -1,6 +1,4 @@
import math
import sys
import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -18,11 +16,7 @@ if model_management.xformers_enabled():
import xformers.ops
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
from sageattention import sageattn
from comfy.cli_args import args
import comfy.ops
@@ -30,24 +24,38 @@ ops = comfy.ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision, current_dtype):
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@@ -82,7 +90,7 @@ def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -151,7 +159,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, query.dtype)
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
b, _, _, dim_head = query.shape
@@ -221,7 +229,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
b, _, _, dim_head = q.shape

View File

@@ -321,7 +321,7 @@ class SelfAttention(nn.Module):
class RMSNorm(torch.nn.Module):
def __init__(
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
):
"""
Initialize the RMSNorm normalization layer.

View File

@@ -293,17 +293,6 @@ def pytorch_attention(q, k, v):
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled_vae():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.info("Using split attention in VAE")
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@@ -331,7 +320,15 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
self.optimized_attention = vae_attention()
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x):
h_ = x
@@ -702,6 +699,9 @@ class Decoder(nn.Module):
padding=1)
def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None

View File

@@ -1,480 +0,0 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float32)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
# output
x = x + img_x
x = self.o(x)
return x
WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
}
class WanAttentionBlock(nn.Module):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps, operation_settings=operation_settings)
self.norm3 = operation_settings.get("operations").LayerNorm(
dim, eps,
elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
num_heads,
(-1, -1),
qk_norm,
eps, operation_settings=operation_settings)
self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn = nn.Sequential(
operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(
self,
x,
e,
freqs,
context,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0],
freqs)
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to False):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
assert model_type in ['t2v', 'i2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = operations.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
self.text_embedding = nn.Sequential(
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_embedding = nn.Sequential(
operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
d = dim // num_heads
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
else:
self.img_emb = None
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
):
r"""
Forward pass through the diffusion model
Args:
x (Tensor):
List of input video tensors with shape [B, C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [B, L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
freqs=freqs,
context=context)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
"""
c = self.out_dim
u = x
b = u.shape[0]
u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
return u

View File

@@ -1,567 +0,0 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class CausalConv3d(ops.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).chunk(3, dim=1)
x = self.optimized_attention(q, k, v)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -307,6 +307,7 @@ def model_lora_keys_unet(model, key_map={}):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
@@ -326,13 +327,6 @@ def model_lora_keys_unet(model, key_map={}):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
if isinstance(model, comfy.model_base.StableCascade_C):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_prior_unet_{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:

View File

@@ -34,8 +34,6 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.model_management
import comfy.patcher_extension
@@ -150,9 +148,7 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
@@ -161,16 +157,15 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def process_timestep(self, timestep, **kwargs):
return timestep
def get_dtype(self):
return self.diffusion_model.dtype
def is_adm(self):
return self.adm_channels > 0
def encode_adm(self, **kwargs):
return None
@@ -189,11 +184,6 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5:
if concat_latent_image.shape[-3] < noise.shape[-3]:
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
else:
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
@@ -222,11 +212,6 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1])
if ck == "concat_image":
if concat_latent_image is not None:
cond_concat.append(concat_latent_image.to(device))
else:
cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return None
@@ -564,10 +549,6 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class IP2P:
@@ -825,10 +806,7 @@ class Flux(BaseModel):
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
class GenmoMochi(BaseModel):
@@ -859,26 +837,17 @@ class LTXV(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -894,18 +863,9 @@ class HunyuanVideo(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
@@ -932,63 +892,3 @@ class CosmosVideo(BaseModel):
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
if not self.image_to_video:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out

View File

@@ -1,4 +1,3 @@
import json
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -34,7 +33,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None
def detect_unet_config(state_dict, key_prefix, metadata=None):
def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@@ -137,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["in_channels"] = 16
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
@@ -211,8 +210,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@@ -242,7 +239,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["micro_condition"] = False
return dit_config
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "cosmos"
dit_config["max_img_h"] = 240
@@ -287,42 +284,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
dit_config["dim"] = dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
dit_config["window_size"] = (-1, -1)
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -457,8 +418,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
logging.error("no match {}".format(unet_config))
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)

View File

@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args
import torch
import sys
import platform
@@ -50,9 +50,7 @@ xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
except:
pass
@@ -95,13 +93,6 @@ try:
except:
npu_available = False
try:
import torch_mlu # noqa: F401
_ = torch.mlu.device_count()
mlu_available = torch.mlu.is_available()
except:
mlu_available = False
if args.cpu:
cpu_state = CPUState.CPU
@@ -119,12 +110,6 @@ def is_ascend_npu():
return True
return False
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@@ -140,8 +125,6 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
else:
return torch.device(torch.cuda.current_device())
@@ -168,12 +151,6 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -241,7 +218,7 @@ def is_amd():
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
MIN_WEIGHT_MEMORY_RATIO = 0.2
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
@@ -250,45 +227,22 @@ if args.use_pytorch_cross_attention:
try:
if is_nvidia():
if torch_version_numeric[0] >= 2:
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if is_intel_xpu() or is_ascend_npu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
except:
pass
try:
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -302,10 +256,15 @@ elif args.highvram or args.gpu_only:
vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False
FORCE_FP16 = False
if args.force_fp32:
logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if args.force_fp16:
logging.info("Forcing FP16.")
FORCE_FP16 = True
if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
@@ -338,8 +297,6 @@ def get_torch_device_name(device):
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
elif is_mlu():
return "{} {}".format(device, torch.mlu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -578,11 +535,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -675,7 +635,7 @@ def unet_inital_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
@@ -693,8 +653,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
fp8_dtype = None
try:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
except:
pass
@@ -706,10 +668,6 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if PRIORITIZE_FP16 or weight_dtype == torch.float16:
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
@@ -742,9 +700,6 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@@ -930,8 +885,6 @@ def xformers_enabled():
return False
if is_ascend_npu():
return False
if is_mlu():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -948,11 +901,6 @@ def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_enabled_vae():
if is_amd():
return False # enabling pytorch attention on AMD currently causes crash when doing high res
return pytorch_attention_enabled()
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
@@ -963,10 +911,6 @@ def pytorch_attention_flash_attention():
return True
if is_ascend_npu():
return True
if is_mlu():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
def mac_version():
@@ -979,11 +923,11 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
upcast = True
if upcast:
return {torch.float16: torch.float32}
return torch.float32
else:
return None
@@ -1013,13 +957,6 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@@ -1056,26 +993,21 @@ def is_device_mps(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def is_directml_enabled():
global directml_enabled
if directml_enabled:
return True
return False
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
if device is not None:
if is_device_cpu(device):
return False
if args.force_fp16:
if FORCE_FP16:
return True
if FORCE_FP32:
return False
if is_directml_enabled():
return True
if directml_enabled:
return False
if (device is not None and is_device_mps(device)) or mps_mode():
return True
@@ -1089,9 +1021,6 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True
if is_mlu():
return True
if torch.version.hip:
return True
@@ -1149,28 +1078,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
if is_ascend_npu():
return True
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False
props = torch.cuda.get_device_properties(device)
if is_mlu():
if props.major > 3:
return True
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works and manual_cast:
if bf16_works or manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -1189,11 +1103,11 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
return False
if WINDOWS:
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
return False
return True

View File

@@ -96,28 +96,8 @@ def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
if hasattr(m, "weight_function"):
m.weight_function = []
if hasattr(m, "bias_function"):
m.bias_function = []
def move_weight_functions(m, device):
if device is None:
return 0
memory = 0
if hasattr(m, "weight_function"):
for f in m.weight_function:
if hasattr(f, "move_to"):
memory += f.move_to(device=device)
if hasattr(m, "bias_function"):
for f in m.bias_function:
if hasattr(f, "move_to"):
memory += f.move_to(device=device)
return memory
m.weight_function = None
m.bias_function = None
class LowVramPatch:
def __init__(self, key, patches):
@@ -212,13 +192,11 @@ class ModelPatcher:
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
@@ -272,14 +250,11 @@ class ModelPatcher:
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.force_cast_weights = self.force_cast_weights
# attachments
n.attachments = {}
for k in self.attachments:
@@ -427,16 +402,6 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def set_model_compute_dtype(self, dtype):
self.add_object_patch("manual_cast_dtype", dtype)
if dtype is not None:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
self.patches_uuid = uuid.uuid4()
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
@@ -601,9 +566,6 @@ class ModelPatcher:
lowvram_weight = False
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
@@ -611,46 +573,34 @@ class ModelPatcher:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
cast_weight = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
m.bias_function = []
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
cast_weight = True
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
if weight_key in self.weight_wrapper_patches:
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
if bias_key in self.weight_wrapper_patches:
m.bias_function.extend(self.weight_wrapper_patches[bias_key])
mem_counter += move_weight_functions(m, device_to)
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
@@ -712,7 +662,6 @@ class ModelPatcher:
self.unpatch_hooks()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.model.model_lowvram = False
@@ -779,19 +728,15 @@ class ModelPatcher:
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
cast_weight = True
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False

View File

@@ -31,7 +31,6 @@ class EPS:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -62,11 +61,9 @@ class CONST:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module):

View File

@@ -18,7 +18,7 @@
import torch
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -38,23 +38,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
has_function = s.bias_function is not None
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
for f in s.bias_function:
bias = f(bias)
bias = s.bias_function(bias)
has_function = len(s.weight_function) > 0
has_function = s.weight_function is not None
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
for f in s.weight_function:
weight = f(weight)
weight = s.weight_function(weight)
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = []
bias_function = []
weight_function = None
bias_function = None
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -66,7 +64,7 @@ class disable_weight_init:
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -80,7 +78,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -94,7 +92,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -108,7 +106,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -122,11 +120,12 @@ class disable_weight_init:
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@@ -140,7 +139,7 @@ class disable_weight_init:
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -161,7 +160,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -182,7 +181,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -200,7 +199,7 @@ class disable_weight_init:
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
@@ -360,11 +359,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if (
fp8_compute and
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
not disable_fast_fp8
):
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:

View File

@@ -58,6 +58,7 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()

View File

@@ -12,6 +12,7 @@ import collections
from comfy import model_management
import math
import logging
import comfy.samplers
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
@@ -19,12 +20,6 @@ import comfy.hooks
import scipy.stats
import numpy
def add_area_dims(area, num_dims):
while (len(area) // 2) < num_dims:
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
return area
def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:])
area = None
@@ -40,10 +35,6 @@ def get_area_and_mult(conds, x_in, timestep_in):
return None
if 'area' in conds:
area = list(conds['area'])
area = add_area_dims(area, len(dims))
if (len(area) // 2) > len(dims):
area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
if 'strength' in conds:
strength = conds['strength']
@@ -60,7 +51,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert (mask.shape[1:] == x_in.shape[2:])
assert(mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -74,17 +65,16 @@ def get_area_and_mult(conds, x_in, timestep_in):
mult = mask * strength
if 'mask' not in conds and area is not None:
fuzz = 8
rr = 8
for i in range(len(dims)):
rr = min(fuzz, mult.shape[2 + i] // 4)
if area[len(dims) + i] != 0:
for t in range(rr):
m = mult.narrow(i + 2, t, 1)
m *= ((1.0 / rr) * (t + 1))
m *= ((1.0/rr) * (t + 1))
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
for t in range(rr):
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
m *= ((1.0 / rr) * (t + 1))
m *= ((1.0/rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
@@ -188,7 +178,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = get_area_and_mult(x, x_in, timestep)
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
@@ -225,7 +215,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
@@ -559,37 +549,25 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
def create_cond_with_same_area_if_none(conds, c):
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
if 'area' not in c:
return
def area_inside(a, area_cmp):
a = add_area_dims(a, len(area_cmp) // 2)
area_cmp = add_area_dims(area_cmp, len(a) // 2)
a_l = len(a) // 2
area_cmp_l = len(area_cmp) // 2
for i in range(min(a_l, area_cmp_l)):
if a[a_l + i] < area_cmp[area_cmp_l + i]:
return False
for i in range(min(a_l, area_cmp_l)):
if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
return False
return True
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x:
a = x['area']
if area_inside(c_area, a):
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
smallest = x
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
else:
if smallest is None:
smallest = x
@@ -709,8 +687,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation"]
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import json
import torch
from enum import Enum
import logging
@@ -13,7 +12,6 @@ from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import yaml
import math
@@ -38,8 +36,6 @@ import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.model_patcher
import comfy.lora
@@ -135,8 +131,8 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False, **kwargs):
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
@@ -250,7 +246,7 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -358,12 +354,7 @@ class VAE:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
vae_config = None
if metadata is not None and "config" in metadata:
vae_config = json.loads(metadata["config"]).get("vae", None)
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@@ -397,21 +388,9 @@ class VAE:
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -678,8 +657,6 @@ class CLIPType(Enum):
HUNYUAN_VIDEO = 9
PIXART = 10
COSMOS = 11
LUMINA2 = 12
WAN = 13
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -698,7 +675,6 @@ class TEModel(Enum):
T5_BASE = 6
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -717,8 +693,6 @@ def detect_te_model(sd):
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -756,7 +730,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
tokenizer_data = {}
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
@@ -784,10 +757,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
elif clip_type == CLIPType.WAN:
clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -800,10 +769,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
@@ -833,6 +798,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -879,13 +845,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None
clipvision = None
vae = None
@@ -897,19 +863,19 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -926,7 +892,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
vae = VAE(sd=vae_sd)
if output_clip:
clip_target = model_config.clip_target(state_dict=sd)
@@ -1000,11 +966,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
else:
unet_dtype = dtype

View File

@@ -421,10 +421,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
self.end_token = None
@@ -482,7 +482,7 @@ class SDTokenizer:
return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors.
@@ -585,18 +585,13 @@ class SDTokenizer:
return {}
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)
else:
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out
@@ -605,7 +600,7 @@ class SD1Tokenizer:
return getattr(self, self.clip).untokenize(token_weight_pair)
def state_dict(self):
return getattr(self, self.clip).state_dict()
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@@ -26,7 +26,7 @@ class SDXLTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@@ -15,8 +15,6 @@ import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
from . import supported_models_base
from . import latent_formats
@@ -762,7 +760,7 @@ class LTXV(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
memory_usage_factor = 2.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -790,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.8 #TODO
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -826,16 +824,6 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
return out
class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
@@ -851,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 1.6 #TODO
memory_usage_factor = 2.4 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
@@ -877,78 +865,6 @@ class CosmosI2V(CosmosT2V):
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
return out
class Lumina2(supported_models_base.BASE):
unet_config = {
"image_model": "lumina2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 6.0,
}
memory_usage_factor = 1.2
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Lumina2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
}
sampling_settings = {
"shift": 8.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V]
models += [SVD_img2vid]

View File

@@ -118,7 +118,7 @@ class BertModel_(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
x, i = self.encoder(x, mask, intermediate_output)
return x, i

View File

@@ -18,7 +18,7 @@ class FluxTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)

View File

@@ -41,14 +41,11 @@ class HunyuanVideoTokenizer:
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False, llama_template=None, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
if llama_template is None:
llama_text = "{}{}".format(self.llama_template, text)
else:
llama_text = "{}{}".format(llama_template, text)
llama_text = "{}{}".format(self.llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
return out

View File

@@ -37,7 +37,7 @@ class HyditTokenizer:
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
@@ -20,41 +21,15 @@ class Llama2Config:
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
@dataclass
class Gemma2_2B_Config:
vocab_size: int = 256000
hidden_size: int = 2304
intermediate_size: int = 9216
num_hidden_layers: int = 26
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
transformer_type: str = "gemma2"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
self.add = add
def forward(self, x: torch.Tensor):
w = self.weight
if self.add:
w = w + 1.0
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
def rotate_half(x):
@@ -93,15 +68,13 @@ class Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.inner_size = self.num_heads * self.head_dim
self.head_dim = self.hidden_size // self.num_heads
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
@@ -111,6 +84,7 @@ class Attention(nn.Module):
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -134,13 +108,9 @@ class MLP(nn.Module):
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
def forward(self, x):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -176,45 +146,6 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = self.post_attention_layernorm(x)
x = residual + x
# MLP
residual = x
x = self.pre_feedforward_layernorm(x)
x = self.mlp(x)
x = self.post_feedforward_layernorm(x)
x = residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@@ -227,27 +158,17 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
transformer = TransformerBlock
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
freqs_cis = precompute_freqs_cis(self.config.head_dim,
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
x.shape[1],
self.config.rope_theta,
device=x.device)
@@ -285,18 +206,8 @@ class Llama2_(nn.Module):
return x, intermediate
class BaseLlama:
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.embed_tokens = embeddings
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class Llama2(BaseLlama, torch.nn.Module):
class Llama2(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
@@ -305,12 +216,11 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.model.embed_tokens
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma2_2B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
def set_input_embeddings(self, embeddings):
self.model.embed_tokens = embeddings
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)

View File

@@ -1,39 +0,0 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return LuminaTEModel_

View File

@@ -43,7 +43,7 @@ class SD3Tokenizer:
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@@ -1,21 +1,21 @@
import torch
class SPieceTokenizer:
@staticmethod
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
add_eos = True
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
self.add_bos = add_bos
self.add_eos = add_eos
@staticmethod
def from_pretrained(path):
return SPieceTokenizer(path)
def __init__(self, tokenizer_path):
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
if isinstance(tokenizer_path, bytes):
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
else:
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
def get_vocab(self):
out = {}

View File

@@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)

View File

@@ -1,22 +0,0 @@
{
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}

View File

@@ -1,37 +0,0 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
class UMT5XXlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)
class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return WanTEModel

View File

@@ -43,29 +43,13 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
if "MetadataIncompleteBuffer" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
@@ -83,7 +67,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
return (sd, metadata) if return_metadata else sd
return sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:

View File

@@ -71,8 +71,8 @@ class CosmosImageToVideoLatent:
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
out_latent["samples"] = latent
out_latent["noise_mask"] = mask
return (out_latent,)

View File

@@ -38,26 +38,7 @@ class FluxGuidance:
return (c, )
class FluxDisableGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
}

View File

@@ -2,14 +2,10 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch
def reshape_latent_to(target_shape, latent, repeat_batch=True):
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
if repeat_batch:
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
else:
return latent
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
class LatentAdd:
@@ -120,7 +116,8 @@ class LatentBatch:
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])

View File

@@ -19,8 +19,14 @@ class Load3D():
"image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -32,14 +38,22 @@ class Load3D():
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
if isinstance(image, dict):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file,
else:
# to avoid the format is not dict which will happen the FE code is not compatibility to core,
# we need to this to double-check, it can be removed after merged FE into the core
image_path = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=image_path)
return output_image, output_mask, model_file,
class Load3DAnimation():
@classmethod
@@ -55,8 +69,15 @@ class Load3DAnimation():
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"animation_speed": (["0.1", "0.5", "1", "1.5", "2"], {"default": "1"}),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -68,42 +89,34 @@ class Load3DAnimation():
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
if isinstance(image, dict):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
return output_image, output_mask, model_file,
return output_image, output_mask, model_file,
else:
image_path = folder_paths.get_annotated_filepath(image)
load_image_node = nodes.LoadImage()
output_image, output_mask = load_image_node.load_image(image=image_path)
return output_image, output_mask, model_file,
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
"show_grid": ([True, False],),
"camera_type": (["perspective", "orthographic"],),
"view": (["front", "right", "top", "isometric"],),
"material": (["original", "normal", "wireframe", "depth"],),
"bg_color": ("STRING", {"default": "#000000", "multiline": False}),
"light_intensity": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1}),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
"fov": ("INT", {"default": 75, "min": 10, "max": 150, "step": 1}),
}}
OUTPUT_NODE = True
@@ -120,13 +133,11 @@ class Preview3DAnimation():
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,
"Load3DAnimation": Load3DAnimation,
"Preview3D": Preview3D,
"Preview3DAnimation": Preview3DAnimation
"Preview3D": Preview3D
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D",
"Load3DAnimation": "Load 3D - Animation",
"Preview3D": "Preview 3D",
"Preview3DAnimation": "Preview 3D - Animation"
"Preview3D": "Preview 3D"
}

View File

@@ -1,14 +1,9 @@
import io
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math
import numpy as np
import av
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
class EmptyLTXVLatentVideo:
@classmethod
@@ -38,6 +33,7 @@ class LTXVImgToVideo:
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@@ -46,220 +42,16 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch_size, 1, latent.shape[2], 1, 1),
dtype=torch.float32,
device=latent.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
return t[1][key]
return default
def get_noise_mask(latent):
noise_mask = latent.get("noise_mask", None)
latent_image = latent["samples"]
if noise_mask is None:
batch_size, _, latent_length, _, _ = latent_image.shape
noise_mask = torch.ones(
(batch_size, 1, latent_length, 1, 1),
dtype=torch.float32,
device=latent_image.device,
)
else:
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
"Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
keyframe_idxs = pixel_coords
else:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = latent_image.clone()
noise_mask = noise_mask.clone()
latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
if frame_idx == 0:
latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe(
positive,
negative,
frame_idx,
latent_image,
noise_mask,
t[:, :, :num_prefix_frames],
strength,
scale_factors,
)
latent_idx += num_prefix_frames
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image, noise_mask = self.replace_latent_frames(
latent_image,
noise_mask,
t,
latent_idx,
strength,
)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVCropGuides:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT",),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return (positive, negative, {"samples": latent}, )
class LTXVConditioning:
@@ -382,78 +174,6 @@ class LTXVScheduler:
return (sigmas,)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file):
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def preprocess(image: torch.Tensor, crf=29):
if crf == 0:
return image
image_array = (image * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
class LTXVPreprocess:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"img_compression": (
"INT",
{
"default": 35,
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
),
}
}
FUNCTION = "preprocess"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
output_image = image
if img_compression > 0:
output_image = torch.zeros_like(image)
for i in range(image.shape[0]):
output_image[i] = preprocess(image[i], img_compression)
return (output_image,)
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
@@ -461,7 +181,4 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
}

View File

@@ -1,104 +0,0 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import torch
class RenormCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, cfg_trunc, renorm_cfg):
def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"]
cond_scale = args["cond_scale"]
timestep = args["timestep"]
x_orig = args["input"]
in_channels = model.model.diffusion_model.in_channels
if timestep[0] < cfg_trunc:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
half_rest = cond_rest
if float(renorm_cfg) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(cond_eps
, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
)
max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm)
else:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = cond_eps
half_rest = cond_rest
cfg_result = torch.cat([half_eps, half_rest], dim=1)
# cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale
return x_orig - cfg_result
m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func)
return (m, )
class CLIPTextEncodeLumina2(ComfyNodeABC):
SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts.",
"alignment": "You are an assistant designed to generate high-quality images with the "\
"highest degree of image-text alignment based on textual prompts."
}
SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
"Superior: You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts. "\
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts."
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = (IO.CONDITIONING,)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, user_prompt, system_prompt):
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt)
return (clip.encode_from_tokens_scheduled(tokens), )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
"RenormCFG": RenormCFG
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}

View File

@@ -3,8 +3,6 @@ import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch
import node_helpers
class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
@@ -296,24 +294,6 @@ class RescaleCFG:
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
class ModelComputeDtype:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"dtype": (["default", "fp32", "fp16", "bf16"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/debug/model"
def patch(self, model, dtype):
m = model.clone()
m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
@@ -323,5 +303,4 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
}

View File

@@ -196,54 +196,6 @@ class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -254,6 +206,4 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
}

View File

@@ -1,76 +0,0 @@
import os
import av
import torch
import folder_paths
import json
from fractions import Fraction
class SaveWEBM:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"codec": (["vp9", "av1"],),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/video"
EXPERIMENTAL = True
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
file = f"{filename}_{counter:05}_.webm"
container = av.open(os.path.join(full_output_folder, file), mode="w")
if prompt is not None:
container.metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
stream.width = images.shape[-2]
stream.height = images.shape[-3]
stream.pix_fmt = "yuv420p"
stream.bit_rate = 0
stream.options = {'crf': str(crf)}
for frame in images:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
container.mux(stream.encode())
container.close()
results = [{
"filename": file,
"subfolder": subfolder,
"type": self.type
}]
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM,
}

View File

@@ -4,7 +4,6 @@ import comfy.utils
import comfy.sd
import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers
class ImageOnlyCheckpointLoader:
@@ -122,38 +121,12 @@ class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class ConditioningSetAreaPercentageVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@@ -1,54 +0,0 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
class WanImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.20"
__version__ = "0.3.10"

View File

@@ -7,18 +7,11 @@ import logging
from typing import Literal
from collections.abc import Collection
from comfy.cli_args import args
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
# --base-directory - Resets all default paths configured in folder_paths with a new base path
if args.base_directory:
base_path = os.path.abspath(args.base_directory)
else:
base_path = os.path.dirname(os.path.realpath(__file__))
base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@@ -46,10 +39,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(base_path, "user")
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

View File

@@ -12,10 +12,7 @@ MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy())

14
main.py
View File

@@ -138,8 +138,8 @@ import server
from server import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
from app.database.db import can_create_session, init_db
from app.model_processor import model_processor
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
@@ -264,6 +264,11 @@ def start_comfyui(asyncio_loop=None):
cuda_malloc_warning()
try:
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
prompt_server.add_routes()
hijack_progress(prompt_server)
@@ -271,6 +276,10 @@ def start_comfyui(asyncio_loop=None):
if args.quick_test_for_ci:
exit(0)
# Scan for changed model files and update db
if can_create_session():
model_processor.run()
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None
@@ -294,7 +303,6 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui()
try:
event_loop.run_until_complete(start_all_func())

View File

@@ -1,5 +1,4 @@
import hashlib
import torch
from comfy.cli_args import args
@@ -36,11 +35,3 @@ def hasher():
"sha512": hashlib.sha512
}
return hashfuncs[args.default_hashing_function]
def string_to_torch_dtype(string):
if string == "fp32":
return torch.float32
if string == "fp16":
return torch.float16
if string == "bf16":
return torch.bfloat16

View File

@@ -63,8 +63,6 @@ class CLIPTextEncode(ComfyNodeABC):
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, text):
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens), )
@@ -914,7 +912,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -924,7 +922,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
if type == "stable_cascade":
@@ -939,12 +937,6 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
elif type == "cosmos":
clip_type = comfy.sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = comfy.sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@@ -1066,11 +1058,10 @@ class StyleModelApply:
for t in conditioning:
(txt, keys) = t
keys = keys.copy()
# even if the strength is 1.0 (i.e, no change), if there's already a mask, we have to add to it
if "attention_mask" in keys or (strength_type == "attn_bias" and strength != 1.0):
if strength_type == "attn_bias" and strength != 1.0:
# math.log raises an error if the argument is zero
# torch.log returns -inf, which is what we want
attn_bias = torch.log(torch.Tensor([strength if strength_type == "attn_bias" else 1.0]))
attn_bias = torch.log(torch.Tensor([strength]))
# get the size of the mask image
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
n_ref = mask_ref_size[0] * mask_ref_size[1]
@@ -1765,36 +1756,6 @@ class LoadImageMask:
return True
class LoadImageOutput(LoadImage):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("COMBO", {
"image_upload": True,
"image_folder": "output",
"remote": {
"route": "/internal/files/output",
"refresh_button": True,
"control_after_refresh": "first",
},
}),
}
}
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True
FUNCTION = "load_image_output"
def load_image_output(self, image):
return self.load_image(f"{image} [output]")
@classmethod
def VALIDATE_INPUTS(s, image):
return True
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@@ -1981,7 +1942,6 @@ NODE_CLASS_MAPPINGS = {
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"LoadImageOutput": LoadImageOutput,
"ImageScale": ImageScale,
"ImageScaleBy": ImageScaleBy,
"ImageInvert": ImageInvert,
@@ -2082,7 +2042,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"PreviewImage": "Preview Image",
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)",
"LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
@@ -2267,9 +2226,6 @@ def init_builtin_extra_nodes():
"nodes_hooks.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.20"
version = "0.3.10"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,24 +1,22 @@
comfyui-frontend-package==1.10.17
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
aiohttp>=3.11.8
yarl>=1.18.0
aiohttp
pyyaml
Pillow
scipy
tqdm
psutil
alembic
SQLAlchemy
#non essential dependencies:
kornia>=0.7.1
spandrel
soundfile
av

View File

@@ -52,20 +52,6 @@ async def cache_control(request: web.Request, handler):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@web.middleware
async def compress_body(request: web.Request, handler):
accept_encoding = request.headers.get("Accept-Encoding", "")
response: web.Response = await handler(request)
if not isinstance(response, web.Response):
return response
if response.content_type not in ["application/json", "text/plain"]:
return response
if response.body and "gzip" in accept_encoding:
response.enable_compression()
return response
def create_cors_middleware(allowed_origin: str):
@web.middleware
async def cors_middleware(request: web.Request, handler):
@@ -150,8 +136,7 @@ class PromptServer():
PromptServer.instance = self
mimetypes.init()
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
mimetypes.add_type('image/webp', '.webp')
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
@@ -165,9 +150,6 @@ class PromptServer():
self.number = 0
middlewares = [cache_control]
if args.enable_compress_response_body:
middlewares.append(compress_body)
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
@@ -347,9 +329,6 @@ class PromptServer():
original_ref = json.loads(post.get("original_ref"))
filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
if not filename:
return web.Response(status=400)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
return web.Response(status=400)
@@ -391,9 +370,6 @@ class PromptServer():
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
if not filename:
return web.Response(status=400)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
return web.Response(status=400)

View File

@@ -2,146 +2,39 @@ import pytest
from aiohttp import web
from unittest.mock import patch
from app.custom_node_manager import CustomNodeManager
import json
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def custom_node_manager():
return CustomNodeManager()
@pytest.fixture
def app(custom_node_manager):
app = web.Application()
routes = web.RouteTableDef()
custom_node_manager.add_routes(
routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")]
)
custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")])
app.add_routes(routes)
return app
async def test_get_workflow_templates(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
# Setup temporary custom nodes file structure with 1 workflow file
custom_nodes_dir = tmp_path / "custom_nodes"
example_workflows_dir = (
custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
)
example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows"
example_workflows_dir.mkdir(parents=True)
template_file = example_workflows_dir / "workflow1.json"
template_file.write_text("")
template_file.write_text('')
with patch(
"folder_paths.folder_names_and_paths",
{"custom_nodes": ([str(custom_nodes_dir)], None)},
):
response = await client.get("/workflow_templates")
with patch('folder_paths.folder_names_and_paths', {
'custom_nodes': ([str(custom_nodes_dir)], None)
}):
response = await client.get('/workflow_templates')
assert response.status == 200
workflows_dict = await response.json()
assert isinstance(workflows_dict, dict)
assert "ComfyUI-TestExtension1" in workflows_dict
assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list)
assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"
async def test_build_translations_empty_when_no_locales(custom_node_manager, tmp_path):
custom_nodes_dir = tmp_path / "custom_nodes"
custom_nodes_dir.mkdir(parents=True)
with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
translations = custom_node_manager.build_translations()
assert translations == {}
async def test_build_translations_loads_all_files(custom_node_manager, tmp_path):
# Setup test directory structure
custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
locales_dir = custom_nodes_dir / "locales" / "en"
locales_dir.mkdir(parents=True)
# Create test translation files
main_content = {"title": "Test Extension"}
(locales_dir / "main.json").write_text(json.dumps(main_content))
node_defs = {"node1": "Node 1"}
(locales_dir / "nodeDefs.json").write_text(json.dumps(node_defs))
commands = {"cmd1": "Command 1"}
(locales_dir / "commands.json").write_text(json.dumps(commands))
settings = {"setting1": "Setting 1"}
(locales_dir / "settings.json").write_text(json.dumps(settings))
with patch(
"folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
):
translations = custom_node_manager.build_translations()
assert translations == {
"en": {
"title": "Test Extension",
"nodeDefs": {"node1": "Node 1"},
"commands": {"cmd1": "Command 1"},
"settings": {"setting1": "Setting 1"},
}
}
async def test_build_translations_handles_invalid_json(custom_node_manager, tmp_path):
# Setup test directory structure
custom_nodes_dir = tmp_path / "custom_nodes" / "test-extension"
locales_dir = custom_nodes_dir / "locales" / "en"
locales_dir.mkdir(parents=True)
# Create valid main.json
main_content = {"title": "Test Extension"}
(locales_dir / "main.json").write_text(json.dumps(main_content))
# Create invalid JSON file
(locales_dir / "nodeDefs.json").write_text("invalid json{")
with patch(
"folder_paths.get_folder_paths", return_value=[tmp_path / "custom_nodes"]
):
translations = custom_node_manager.build_translations()
assert translations == {
"en": {
"title": "Test Extension",
}
}
async def test_build_translations_merges_multiple_extensions(
custom_node_manager, tmp_path
):
# Setup test directory structure for two extensions
custom_nodes_dir = tmp_path / "custom_nodes"
ext1_dir = custom_nodes_dir / "extension1" / "locales" / "en"
ext2_dir = custom_nodes_dir / "extension2" / "locales" / "en"
ext1_dir.mkdir(parents=True)
ext2_dir.mkdir(parents=True)
# Create translation files for extension 1
ext1_main = {"title": "Extension 1", "shared": "Original"}
(ext1_dir / "main.json").write_text(json.dumps(ext1_main))
# Create translation files for extension 2
ext2_main = {"description": "Extension 2", "shared": "Override"}
(ext2_dir / "main.json").write_text(json.dumps(ext2_main))
with patch("folder_paths.get_folder_paths", return_value=[str(custom_nodes_dir)]):
translations = custom_node_manager.build_translations()
assert translations == {
"en": {
"title": "Extension 1",
"description": "Extension 2",
"shared": "Override", # Second extension should override first
}
}

View File

@@ -7,11 +7,33 @@ from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
from app.database.models import Base, Model, Tag
from comfy.cli_args import args
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def session():
# Configure in-memory database
args.database_url = "sqlite:///:memory:"
# Create engine and session factory
engine = create_engine(args.database_url)
Session = sessionmaker(bind=engine)
# Create all tables
Base.metadata.create_all(engine)
# Patch Session factory
with patch('app.database.db.Session', Session):
yield Session()
Base.metadata.drop_all(engine)
@pytest.fixture
def model_manager():
return ModelFileManager()
@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
# Clean up
img.close()
async def test_get_models(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/models')
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
assert data[0]['path'] == 'model1.safetensors'
assert len(data[0]['tags']) == 1
assert data[0]['tags'][0]['name'] == 'test_tag'
async def test_add_model(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
with patch('app.model_manager.model_processor') as mock_processor:
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Test Model',
'tags': [tag_id]
})
assert resp.status == 200
data = await resp.json()
assert data['path'] == 'model1.safetensors'
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
# Ensure that models are re-processed after adding
mock_processor.run.assert_called_once()
async def test_delete_model(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify model was deleted
model = session.query(Model).first()
assert model is None
async def test_delete_model_file_exists(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "file exists" in data["error"].lower()
# Verify model was not deleted
model = session.query(Model).first()
assert model is not None
assert model.path == 'model1.safetensors'
async def test_get_tags(aiohttp_client, app, session):
tags = [Tag(name='tag1'), Tag(name='tag2')]
for tag in tags:
session.add(tag)
session.commit()
client = await aiohttp_client(app)
resp = await client.get('/v2/tags')
assert resp.status == 200
data = await resp.json()
assert len(data) == 2
assert {t['name'] for t in data} == {'tag1', 'tag2'}
async def test_create_tag(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
assert resp.status == 200
data = await resp.json()
assert data['name'] == 'new_tag'
# Verify tag was created
tag = session.query(Tag).first()
assert tag.name == 'new_tag'
async def test_delete_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/tags?id={tag_id}')
assert resp.status == 204
# Verify tag was deleted
tag = session.query(Tag).first()
assert tag is None
async def test_add_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'model1.safetensors'
})
assert resp.status == 200
data = await resp.json()
assert len(data['tags']) == 1
assert data['tags'][0]['name'] == 'test_tag'
async def test_delete_model_tag(aiohttp_client, app, session):
tag = Tag(name='test_tag')
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
model.tags.append(tag)
session.add(tag)
session.add(model)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
assert resp.status == 204
# Verify tag was removed
model = session.query(Model).first()
assert len(model.tags) == 0
async def test_add_model_duplicate(aiohttp_client, app, session):
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'title': 'Duplicate Model'
})
assert resp.status == 400
async def test_add_model_missing_fields(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={})
assert resp.status == 400
async def test_add_tag_missing_name(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/tags', json={})
assert resp.status == 400
async def test_delete_model_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
assert resp.status == 404
async def test_delete_tag_not_found(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.delete('/v2/tags?id=999')
assert resp.status == 404
async def test_add_model_missing_path(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'title': 'Test Model'
})
assert resp.status == 400
data = await resp.json()
assert "path" in data["error"].lower()
async def test_add_model_invalid_field(aiohttp_client, app, session):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'invalid_field': 'some value'
})
assert resp.status == 400
data = await resp.json()
assert "invalid field" in data["error"].lower()
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value=None):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "file" in data["error"].lower()
async def test_add_model_invalid_tag(aiohttp_client, app, session):
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
client = await aiohttp_client(app)
resp = await client.post('/v2/models', json={
'type': 'checkpoints',
'path': 'model1.safetensors',
'tags': [999] # Non-existent tag ID
})
assert resp.status == 404
data = await resp.json()
assert "tag" in data["error"].lower()
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
# Create a tag but no model
tag = Tag(name='test_tag')
session.add(tag)
session.commit()
tag_id = tag.id
client = await aiohttp_client(app)
resp = await client.post('/v2/models/tags', json={
'tag': tag_id,
'type': 'checkpoints',
'path': 'nonexistent.safetensors'
})
assert resp.status == 404
data = await resp.json()
assert "model" in data["error"].lower()
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
# Create a model first
model = Model(
type='checkpoints',
path='model1.safetensors',
title='Test Model'
)
session.add(model)
session.commit()
client = await aiohttp_client(app)
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
assert resp.status == 400
data = await resp.json()
assert "invalid tag id" in data["error"].lower()

View File

@@ -1,23 +1,19 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import sys
import pytest
import os
import tempfile
from unittest.mock import patch
from importlib import reload
import folder_paths
import comfy.cli_args
from comfy.options import enable_args_parsing
enable_args_parsing()
@pytest.fixture()
def clear_folder_paths():
# Reload the module after each test to ensure isolation
# Clear the global dictionary before each test to ensure isolation
original = folder_paths.folder_names_and_paths.copy()
folder_paths.folder_names_and_paths.clear()
yield
reload(folder_paths)
folder_paths.folder_names_and_paths = original
@pytest.fixture
def temp_dir():
@@ -25,21 +21,7 @@ def temp_dir():
yield tmpdirname
@pytest.fixture
def set_base_dir():
def _set_base_dir(base_dir):
# Mock CLI args
with patch.object(sys, 'argv', ["main.py", "--base-directory", base_dir]):
reload(comfy.cli_args)
reload(folder_paths)
yield _set_base_dir
# Reload the modules after each test to ensure isolation
with patch.object(sys, 'argv', ["main.py"]):
reload(comfy.cli_args)
reload(folder_paths)
def test_get_directory_by_type(clear_folder_paths):
def test_get_directory_by_type():
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
@@ -114,49 +96,3 @@ def test_get_save_image_path(temp_dir):
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"
def test_base_path_changes(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)
assert folder_paths.base_path == test_dir
assert folder_paths.models_dir == os.path.join(test_dir, "models")
assert folder_paths.input_directory == os.path.join(test_dir, "input")
assert folder_paths.output_directory == os.path.join(test_dir, "output")
assert folder_paths.temp_directory == os.path.join(test_dir, "temp")
assert folder_paths.user_directory == os.path.join(test_dir, "user")
assert os.path.join(test_dir, "custom_nodes") in folder_paths.get_folder_paths("custom_nodes")
for name in ["checkpoints", "loras", "vae", "configs", "embeddings", "controlnet", "classifiers"]:
assert folder_paths.get_folder_paths(name)[0] == os.path.join(test_dir, "models", name)
def test_base_path_change_clears_old(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)
assert len(folder_paths.get_folder_paths("custom_nodes")) == 1
single_model_paths = [
"checkpoints",
"loras",
"vae",
"configs",
"clip_vision",
"style_models",
"diffusers",
"vae_approx",
"gligen",
"upscale_models",
"embeddings",
"hypernetworks",
"photomaker",
"classifiers",
]
for name in single_model_paths:
assert len(folder_paths.get_folder_paths(name)) == 1
for name in ["controlnet", "diffusion_models", "text_encoders"]:
assert len(folder_paths.get_folder_paths(name)) == 2

Some files were not shown because too many files have changed in this diff Show More