Compare commits
92 Commits
v0.3.33
...
pysssss-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d5160f92c | ||
|
|
7f7b3f1695 | ||
|
|
9da6aca0d0 | ||
|
|
1cb3c98947 | ||
|
|
d3bd983b91 | ||
|
|
fb4754624d | ||
|
|
180db6753f | ||
|
|
d062fcc5c0 | ||
|
|
456abad834 | ||
|
|
19e45e9b0e | ||
|
|
97f23b81f3 | ||
|
|
08b7cc7506 | ||
|
|
6c319cbb4e | ||
|
|
df1aebe52e | ||
|
|
704fc78854 | ||
|
|
1d9fee79fd | ||
|
|
aeba0b3a26 | ||
|
|
094306b626 | ||
|
|
31260f0275 | ||
|
|
f1c9ca816a | ||
|
|
f2289a1f59 | ||
|
|
fb83eda287 | ||
|
|
5e5e46d40c | ||
|
|
4eba3161cf | ||
|
|
592d056100 | ||
|
|
1c1687ab1c | ||
|
|
e6609dacde | ||
|
|
ba37e67964 | ||
|
|
06c661004e | ||
|
|
c9e1821a7b | ||
|
|
f58f0f5696 | ||
|
|
3a10b9641c | ||
|
|
89a84e32d2 | ||
|
|
e5799c4899 | ||
|
|
a0651359d7 | ||
|
|
ad3bd8aa49 | ||
|
|
5a87757ef9 | ||
|
|
464aece92b | ||
|
|
0b50d4c0db | ||
|
|
30b2eb8a93 | ||
|
|
f85c08df06 | ||
|
|
4202e956a0 | ||
|
|
b838c36720 | ||
|
|
fc39184ea9 | ||
|
|
ded60c33a0 | ||
|
|
8bb858e4d3 | ||
|
|
57893c843f | ||
|
|
65da29aaa9 | ||
|
|
10024a38ea | ||
|
|
87f9130778 | ||
|
|
7e84bf5373 | ||
|
|
4f3b50ba51 | ||
|
|
e930a387d6 | ||
|
|
d8e5662822 | ||
|
|
3d44a09812 | ||
|
|
62690eddec | ||
|
|
05eb10b43a | ||
|
|
f5e4e976f4 | ||
|
|
aee2908d03 | ||
|
|
dc46db7aa4 | ||
|
|
7046983d95 | ||
|
|
1c2d45d2b5 | ||
|
|
c820ef950d | ||
|
|
6a2e4bb9e0 | ||
|
|
f1f9763b4c | ||
|
|
08368f8e00 | ||
|
|
f3ff5c40db | ||
|
|
98ff01e148 | ||
|
|
bab836d88d | ||
|
|
4a9014e201 | ||
|
|
8a7c894d54 | ||
|
|
a814f2e8cc | ||
|
|
481732a0ed | ||
|
|
2156ce9453 | ||
|
|
4136502b7a | ||
|
|
9ad287ff20 | ||
|
|
f5cacaeb14 | ||
|
|
b7ed5f57bd | ||
|
|
b4abca828e | ||
|
|
158419f3a0 | ||
|
|
640c47e7de | ||
|
|
31e9e36c94 | ||
|
|
577de83ca9 | ||
|
|
3535909eb8 | ||
|
|
235d3901fc | ||
|
|
d42613686f | ||
|
|
1b3bf0a5da | ||
|
|
ae60b150e5 | ||
|
|
42da274717 | ||
|
|
28f178a840 | ||
|
|
8ab15c863c | ||
|
|
924d771e18 |
26
CODEOWNERS
26
CODEOWNERS
@@ -5,20 +5,20 @@
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
|
||||
13
README.md
13
README.md
@@ -69,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- 3D Models
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@@ -108,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
- Builds a new release using the latest stable core version
|
||||
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
- Weekly frontend updates are merged into the core repository
|
||||
@@ -196,11 +197,11 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
@@ -300,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
|
||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
|
||||
84
alembic.ini
Normal file
84
alembic.ini
Normal file
@@ -0,0 +1,84 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic_db/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
4
alembic_db/README.md
Normal file
4
alembic_db/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
## Generate new revision
|
||||
|
||||
1. Update models in `/app/database/models.py`
|
||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
||||
69
alembic_db/env.py
Normal file
69
alembic_db/env.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic_db/script.py.mako
Normal file
28
alembic_db/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
40
alembic_db/versions/e9c714da8d57_init.py
Normal file
40
alembic_db/versions/e9c714da8d57_init.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""init
|
||||
|
||||
Revision ID: e9c714da8d57
|
||||
Revises:
|
||||
Create Date: 2025-05-30 20:14:33.772039
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e9c714da8d57'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.create_table('model',
|
||||
sa.Column('type', sa.Text(), nullable=False),
|
||||
sa.Column('path', sa.Text(), nullable=False),
|
||||
sa.Column('file_name', sa.Text(), nullable=True),
|
||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
||||
sa.Column('hash', sa.Text(), nullable=True),
|
||||
sa.Column('hash_algorithm', sa.Text(), nullable=True),
|
||||
sa.Column('source_url', sa.Text(), nullable=True),
|
||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('type', 'path')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('model')
|
||||
# ### end Alembic commands ###
|
||||
112
app/database/db.py
Normal file
112
app/database/db.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.error(f"Error upgrading database: {e}")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
59
app/database/models.py
Normal file
59
app/database/models.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
DateTime,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
sqlalchemy model representing a model file in the system.
|
||||
|
||||
This class defines the database schema for storing information about model files,
|
||||
including their type, path, hash, and when they were added to the system.
|
||||
|
||||
Attributes:
|
||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
||||
path (Text): The file path of the model relative to the type folder (primary key)
|
||||
file_name (Text): The name of the model file
|
||||
file_size (Integer): The size of the model file in bytes
|
||||
hash (Text): A hash of the model file
|
||||
hash_algorithm (Text): The algorithm used to generate the hash
|
||||
source_url (Text): The URL of the model file
|
||||
date_added (DateTime): Timestamp of when the model was added to the system
|
||||
"""
|
||||
|
||||
__tablename__ = "model"
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
file_name = Column(Text)
|
||||
file_size = Column(Integer)
|
||||
hash = Column(Text)
|
||||
hash_algorithm = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the model instance to a dictionary representation.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the attributes of the model
|
||||
"""
|
||||
dict = to_dict(self)
|
||||
return dict
|
||||
@@ -16,26 +16,15 @@ from importlib.metadata import version
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {req_path}
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||
""".strip()
|
||||
|
||||
|
||||
@@ -48,7 +37,7 @@ def check_frontend_version():
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
@@ -162,10 +151,30 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
"""
|
||||
A class to manage ComfyUI frontend versions and installations.
|
||||
|
||||
This class handles the initialization and management of different frontend versions,
|
||||
including the default frontend from the pip package and custom frontend versions
|
||||
from GitHub repositories.
|
||||
|
||||
Attributes:
|
||||
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
|
||||
"""
|
||||
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the default frontend installation from the pip package.
|
||||
|
||||
Returns:
|
||||
str: The path to the default frontend static files.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-frontend-package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
@@ -186,6 +195,15 @@ comfyui-frontend-package is not installed.
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the workflow templates.
|
||||
|
||||
Returns:
|
||||
str: The path to the workflow templates directory.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-workflow-templates package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@@ -205,14 +223,32 @@ comfyui-workflow-templates is not installed.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
"""Get the path to embedded documentation"""
|
||||
try:
|
||||
import comfyui_embedded_docs
|
||||
|
||||
return str(
|
||||
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
||||
)
|
||||
except ImportError:
|
||||
logging.info("comfyui-embedded-docs package not found")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Parse a version string into its components.
|
||||
|
||||
The version string should be in the format: 'owner/repo@version'
|
||||
where version can be either a semantic version (v1.2.3) or 'latest'.
|
||||
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
tuple[str, str, str]: A tuple containing (owner, repo, version).
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
@@ -229,18 +265,22 @@ comfyui-workflow-templates is not installed.
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
Initialize a frontend version without error handling.
|
||||
|
||||
This method attempts to initialize a specific frontend version, either from
|
||||
the default pip package or from a custom GitHub repository. It will download
|
||||
and extract the frontend files if necessary.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
provider (FrontEndProvider, optional): The provider to use for custom frontends.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
Exception: If there is an error during initialization (e.g., network timeout,
|
||||
invalid URL, or missing assets).
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
check_frontend_version()
|
||||
@@ -292,13 +332,17 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
Initialize a frontend version with error handling.
|
||||
|
||||
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
|
||||
with error handling, falling back to the default frontend if initialization fails.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
str: The path to the initialized frontend. If initialization fails,
|
||||
returns the path to the default frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
|
||||
331
app/model_processor.py
Normal file
331
app/model_processor.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from folder_paths import get_relative_path, get_full_path
|
||||
from app.database.db import create_session, dependencies_available, can_create_session
|
||||
import blake3
|
||||
import comfy.utils
|
||||
|
||||
|
||||
if dependencies_available():
|
||||
from app.database.models import Model
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def _validate_path(self, model_path):
|
||||
try:
|
||||
if not self._file_exists(model_path):
|
||||
logging.error(f"Model file not found: {model_path}")
|
||||
return None
|
||||
|
||||
result = get_relative_path(model_path)
|
||||
if not result:
|
||||
logging.error(
|
||||
f"Model file not in a recognized model directory: {model_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _file_exists(self, path):
|
||||
"""Check if a file exists."""
|
||||
return os.path.exists(path)
|
||||
|
||||
def _get_file_size(self, path):
|
||||
"""Get file size."""
|
||||
return os.path.getsize(path)
|
||||
|
||||
def _get_hasher(self):
|
||||
return blake3.blake3()
|
||||
|
||||
def _hash_file(self, model_path):
|
||||
try:
|
||||
hasher = self._get_hasher()
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
||||
return (
|
||||
session.query(Model)
|
||||
.filter(Model.type == model_type)
|
||||
.filter(Model.path == model_relative_path)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _ensure_source_url(self, session, model, source_url):
|
||||
if model.source_url is None:
|
||||
model.source_url = source_url
|
||||
session.commit()
|
||||
|
||||
def _update_database(
|
||||
self,
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
model,
|
||||
source_url,
|
||||
):
|
||||
try:
|
||||
if not model:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = Model(
|
||||
path=model_relative_path,
|
||||
type=model_type,
|
||||
file_name=os.path.basename(model_path),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
model.file_size = self._get_file_size(model_path)
|
||||
model.hash = model_hash
|
||||
if model_hash:
|
||||
model.hash_algorithm = "blake3"
|
||||
model.source_url = source_url
|
||||
|
||||
session.commit()
|
||||
return model
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
||||
)
|
||||
|
||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
||||
"""
|
||||
Process a model file and update the database with metadata.
|
||||
If the file already exists and matches the database, it will not be processed again.
|
||||
Returns the model object or if an error occurs, returns None.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
session.expire_on_commit = False
|
||||
|
||||
existing_model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if (
|
||||
existing_model
|
||||
and existing_model.hash
|
||||
and existing_model.file_size == self._get_file_size(model_path)
|
||||
):
|
||||
# File exists with hash and same size, no need to process
|
||||
self._ensure_source_url(session, existing_model, source_url)
|
||||
return existing_model
|
||||
|
||||
if model_hash:
|
||||
model_hash = model_hash.lower()
|
||||
logging.info(f"Using provided hash: {model_hash}")
|
||||
else:
|
||||
start_time = time.time()
|
||||
logging.info(f"Hashing model {model_relative_path}")
|
||||
model_hash = self._hash_file(model_path)
|
||||
if not model_hash:
|
||||
return
|
||||
logging.info(
|
||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
||||
)
|
||||
|
||||
return self._update_database(
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
existing_model,
|
||||
source_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
||||
"""
|
||||
Retrieve a model file from the database by hash and optionally by model type.
|
||||
Returns the model object or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
dispose_session = False
|
||||
|
||||
if session is None:
|
||||
session = create_session()
|
||||
dispose_session = True
|
||||
|
||||
model = session.query(Model).filter(Model.hash == model_hash)
|
||||
if model_type is not None:
|
||||
model = model.filter(Model.type == model_type)
|
||||
return model.first()
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if dispose_session:
|
||||
session.close()
|
||||
|
||||
def retrieve_hash(self, model_path, model_type=None):
|
||||
"""
|
||||
Retrieve the hash of a model file from the database.
|
||||
Returns the hash or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
if model_type is not None:
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return None
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if model and model.hash:
|
||||
return model.hash
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _validate_file_extension(self, file_name):
|
||||
"""Validate that the file extension is supported."""
|
||||
extension = os.path.splitext(file_name)[1]
|
||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
||||
|
||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
||||
"""Check if file exists and has correct hash."""
|
||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
||||
if self._file_exists(destination_path):
|
||||
model = self.process_file(destination_path)
|
||||
if model and (expected_hash is None or model.hash == expected_hash):
|
||||
logging.debug(
|
||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
||||
)
|
||||
return destination_path
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_existing_file_by_hash(self, hash, type, url):
|
||||
"""Check if a file with the given hash exists in the database and on disk."""
|
||||
hash = hash.lower()
|
||||
with create_session() as session:
|
||||
model = self.retrieve_model_by_hash(hash, type, session)
|
||||
if model:
|
||||
existing_path = get_full_path(type, model.path)
|
||||
if existing_path:
|
||||
logging.debug(
|
||||
f"File {model.path} already exists in the database at {existing_path}"
|
||||
)
|
||||
self._ensure_source_url(session, model, url)
|
||||
return existing_path
|
||||
else:
|
||||
logging.debug(
|
||||
f"File {model.path} exists in the database but not on disk"
|
||||
)
|
||||
return None
|
||||
|
||||
def _download_file(self, url, destination_path, hasher):
|
||||
"""Download a file and update the hasher with its contents."""
|
||||
response = requests.get(url, stream=True)
|
||||
logging.info(f"Downloading {url} to {destination_path}")
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if total_size > 0:
|
||||
pbar = comfy.utils.ProgressBar(total_size)
|
||||
else:
|
||||
pbar = None
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
if pbar:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
||||
"""Verify that the downloaded file has the expected hash."""
|
||||
if expected_hash is not None and calculated_hash != expected_hash:
|
||||
self._remove_file(destination_path)
|
||||
raise ValueError(
|
||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
||||
)
|
||||
|
||||
def _remove_file(self, file_path):
|
||||
"""Remove a file from disk."""
|
||||
os.remove(file_path)
|
||||
|
||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
||||
"""
|
||||
Ensure a model file is downloaded and has the correct hash.
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
logging.debug(
|
||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
self._validate_file_extension(desired_file_name)
|
||||
|
||||
# Check if file exists with correct hash
|
||||
if hash:
|
||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Check if file exists locally
|
||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Download the file
|
||||
hasher = self._get_hasher()
|
||||
self._download_file(url, destination_path, hasher)
|
||||
|
||||
# Verify hash
|
||||
calculated_hash = hasher.hexdigest()
|
||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
||||
|
||||
# Update database
|
||||
self.process_file(destination_path, url, calculated_hash)
|
||||
|
||||
# TODO: Notify frontend to reload models
|
||||
|
||||
return destination_path
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
@@ -88,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
@@ -142,6 +143,8 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||
@@ -200,6 +203,12 @@ parser.add_argument(
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node."""
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
||||
@@ -24,6 +24,10 @@ class CONDRegular:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
def size(self):
|
||||
return list(self.cond.size())
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond
|
||||
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
|
||||
class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -78,3 +83,48 @@ class CONDConstant(CONDRegular):
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
|
||||
def size(self):
|
||||
return [1]
|
||||
|
||||
|
||||
class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
def can_concat(self, other):
|
||||
if len(self.cond) != len(other.cond):
|
||||
return False
|
||||
for i in range(len(self.cond)):
|
||||
if self.cond[i].shape != other.cond[i].shape:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
out = []
|
||||
for i in range(len(self.cond)):
|
||||
o = [self.cond[i]]
|
||||
for x in others:
|
||||
o.append(x.cond[i])
|
||||
out.append(torch.cat(o))
|
||||
|
||||
return out
|
||||
|
||||
def size(self): # hackish implementation to make the mem estimation work
|
||||
o = 0
|
||||
c = 1
|
||||
for c in self.cond:
|
||||
size = c.size()
|
||||
o += math.prod(size)
|
||||
if len(size) > 1:
|
||||
c = size[1]
|
||||
|
||||
return [1, c, o // c]
|
||||
|
||||
@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
phi1_fn = lambda t: torch.expm1(t) / t
|
||||
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||
|
||||
old_sigma_down = None
|
||||
old_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
h = t_next - t
|
||||
c2 = (t_prev - t) / h
|
||||
c2 = (t_prev - t_old) / h
|
||||
|
||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
old_denoised = uncond_denoised
|
||||
else:
|
||||
old_denoised = denoised
|
||||
old_sigma_down = sigma_down
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
||||
@@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
# from diffusers.models.modeling_utils import ModelMixin
|
||||
# from diffusers.loaders import FromOriginalModelMixin
|
||||
# from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
@@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
prod(upsample_rates) == hop_length
|
||||
), f"hop_length must be {prod(upsample_rates)}"
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
num_mels,
|
||||
upsample_initial_channel,
|
||||
@@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
@@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
|
||||
self.resblocks.append(ResBlock1(ch, k, d))
|
||||
|
||||
self.activation_post = post_activation()
|
||||
self.conv_post = weight_norm(
|
||||
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
ch,
|
||||
1,
|
||||
|
||||
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
|
||||
@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
|
||||
shift, scale = vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
@@ -163,7 +163,7 @@ class Chroma(nn.Module):
|
||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
||||
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
|
||||
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
img = torch.cat([ref_latent, img], dim=-2)
|
||||
ref_latent_ids[..., 0] = -1
|
||||
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
|
||||
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
|
||||
def img_ids(self, x):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
@@ -20,8 +20,11 @@ if model_management.xformers_enabled():
|
||||
if model_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == "sageattention":
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
else:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
|
||||
@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
return c_skip, c
|
||||
|
||||
|
||||
class WanCamAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
|
||||
super(WanCamAdapter, self).__init__()
|
||||
|
||||
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
||||
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
||||
|
||||
# Convolution: reduce spatial dimensions by a factor
|
||||
# of 2 (without overlap)
|
||||
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
# Residual blocks for feature extraction
|
||||
self.residual_blocks = nn.Sequential(
|
||||
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Reshape to merge the frame dimension into batch
|
||||
bs, c, f, h, w = x.size()
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
||||
|
||||
# Pixel Unshuffle operation
|
||||
x_unshuffled = self.pixel_unshuffle(x)
|
||||
|
||||
# Convolution operation
|
||||
x_conv = self.conv(x_unshuffled)
|
||||
|
||||
# Feature extraction with residual blocks
|
||||
out = self.residual_blocks(x_conv)
|
||||
|
||||
# Reshape to restore original bf dimension
|
||||
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
||||
|
||||
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WanCamResidualBlock(nn.Module):
|
||||
def __init__(self, dim, operation_settings={}):
|
||||
super(WanCamResidualBlock, self).__init__()
|
||||
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += residual
|
||||
return out
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
|
||||
@@ -485,13 +539,20 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
@@ -581,7 +642,7 @@ class VaceWanModel(WanModel):
|
||||
t,
|
||||
context,
|
||||
vace_context,
|
||||
vace_strength=1.0,
|
||||
vace_strength,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
@@ -607,8 +668,11 @@ class VaceWanModel(WanModel):
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
orig_shape = list(vace_context.shape)
|
||||
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
|
||||
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
|
||||
c = c.flatten(2).transpose(1, 2)
|
||||
c = list(c.split(orig_shape[0], dim=0))
|
||||
|
||||
# arguments
|
||||
x_orig = x
|
||||
@@ -628,8 +692,9 @@ class VaceWanModel(WanModel):
|
||||
|
||||
ii = self.vace_layers_mapping.get(i, None)
|
||||
if ii is not None:
|
||||
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x += c_skip * vace_strength
|
||||
for iii in range(len(c)):
|
||||
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x += c_skip * vace_strength[iii]
|
||||
del c_skip
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
@@ -637,3 +702,92 @@ class VaceWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class CameraWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='camera',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
image_model=None,
|
||||
in_dim_control_adapter=24,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
freqs=None,
|
||||
camera_conditions = None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -283,8 +283,15 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
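For context, here is a small sketch of how one checkpoint key fans out into the SimpleTuner aliases built in the mapping above. The key name itself is hypothetical, chosen only to illustrate the string handling.

```python
# Hypothetical checkpoint key, for illustration only.
k = "diffusion_model.double_blocks.0.img_attn.qkv.weight"

key_lora = k[len("diffusion_model."):-len(".weight")]           # "double_blocks.0.img_attn.qkv"
lycoris_key = "lycoris_{}".format(key_lora.replace(".", "_"))   # SimpleTuner lycoris format
regular_key = "transformer.{}".format(key_lora)                 # SimpleTuner regular format

# Both aliases resolve back to the same checkpoint key.
key_map = {lycoris_key: k, regular_key: k}
```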
|
||||
|
||||
@@ -102,6 +102,13 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
return extra
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
@@ -135,6 +142,7 @@ class BaseModel(torch.nn.Module):
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
self.memory_usage_factor_conds = ()
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -164,9 +172,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = convert_tensor(extra, dtype)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
@@ -325,19 +338,28 @@ class BaseModel(torch.nn.Module):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
def memory_required(self, input_shape, cond_shapes={}):
|
||||
input_shapes = [input_shape]
|
||||
for c in self.memory_usage_factor_conds:
|
||||
shape = cond_shapes.get(c, None)
|
||||
if shape is not None and len(shape) > 0:
|
||||
input_shapes += shape
|
||||
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
return {}
|
||||
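As a rough sketch of what the cond-shape-aware estimate does: the conditioning shapes named in `memory_usage_factor_conds` are folded into the same area term as the latent input. The shapes and constants below are illustrative assumptions, not values taken from the diff.

```python
import math

def estimated_area(input_shape, cond_shapes=()):
    # Mirrors the sum(map(...)) term above: batch * spatial elements,
    # summed over the latent input and any extra conditioning shapes.
    shapes = [input_shape] + list(cond_shapes)
    return sum(s[0] * math.prod(s[2:]) for s in shapes)

# Hypothetical shapes: a video latent plus one extra conditioning tensor.
area = estimated_area((2, 16, 21, 60, 104), cond_shapes=[(2, 96, 21, 60, 104)])
# Same formula as the xformers/flash-attention branch above, with assumed
# dtype_size=2 (fp16) and memory_usage_factor=1.0.
estimate = area * 2 * 0.01 * 1.0 * (1024 * 1024)
```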
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
|
||||
adm_inputs = []
|
||||
@@ -924,6 +946,10 @@ class HunyuanVideo(BaseModel):
|
||||
if guiding_frame_index is not None:
|
||||
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
|
||||
|
||||
ref_latent = kwargs.get("ref_latent", None)
|
||||
if ref_latent is not None:
|
||||
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
|
||||
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, latent_image, **kwargs):
|
||||
@@ -1043,6 +1069,11 @@ class WAN21(BaseModel):
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
|
||||
|
||||
time_dim_concat = kwargs.get("time_dim_concat", None)
|
||||
if time_dim_concat is not None:
|
||||
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1058,23 +1089,39 @@ class WAN21_Vace(WAN21):
|
||||
vace_frames = kwargs.get("vace_frames", None)
|
||||
if vace_frames is None:
|
||||
noise_shape[1] = 32
|
||||
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||
|
||||
for i in range(0, vace_frames.shape[1], 16):
|
||||
vace_frames = vace_frames.clone()
|
||||
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
|
||||
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
|
||||
|
||||
mask = kwargs.get("vace_mask", None)
|
||||
if mask is None:
|
||||
noise_shape[1] = 64
|
||||
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
|
||||
|
||||
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
|
||||
vace_frames_out = []
|
||||
for j in range(len(vace_frames)):
|
||||
vf = vace_frames[j].clone()
|
||||
for i in range(0, vf.shape[1], 16):
|
||||
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
|
||||
vf = torch.cat([vf, mask[j]], dim=1)
|
||||
vace_frames_out.append(vf)
|
||||
|
||||
vace_strength = kwargs.get("vace_strength", 1.0)
|
||||
vace_frames = torch.stack(vace_frames_out, dim=1)
|
||||
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
|
||||
|
||||
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
camera_conditions = kwargs.get("camera_conditions", None)
|
||||
if camera_conditions is not None:
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
|
||||
@@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "vace"
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -618,6 +620,9 @@ def convert_config(unet_config):
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
if "conv_in.weight" not in state_dict:
|
||||
return None
|
||||
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
|
||||
@@ -297,11 +297,16 @@ except:
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
pass
|
||||
@@ -695,7 +700,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
if DISABLE_SMART_MEMORY:
|
||||
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
|
||||
return cpu_dev
|
||||
|
||||
model_size = dtype_size(dtype) * parameters
|
||||
@@ -1257,6 +1262,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if args.supports_fp8_compute:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
|
||||
@@ -308,10 +308,10 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
@@ -30,7 +30,7 @@ if RMSNorm is None:
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=None,
|
||||
eps=1e-6,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
for _, cs in conds.items():
|
||||
for cond in cs:
|
||||
for k, v in model.model.extra_conds_shapes(**cond).items():
|
||||
cond_shapes[k].append(v)
|
||||
if cond_shapes_min.get(k, None) is None:
|
||||
cond_shapes_min[k] = [v]
|
||||
elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
|
||||
cond_shapes_min[k] = [v]
|
||||
|
||||
memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
|
||||
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
for tt in batch_amount:
|
||||
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
|
||||
@@ -451,7 +451,7 @@ class VAE:
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
|
||||
@@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
|
||||
out = model_base.WAN21(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera",
|
||||
"in_dim": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
{
|
||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 49407,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 248,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.24.0",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
@@ -28,6 +28,9 @@ import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
@@ -46,10 +49,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def is_html_file(file_path):
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read(100)
|
||||
return b"<!DOCTYPE html>" in content or b"<html" in content
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
metadata = None
|
||||
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
@@ -59,6 +68,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if is_html_file(ckpt):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
if "HeaderTooLarge" in message:
|
||||
@@ -67,12 +78,14 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||
raise e
|
||||
else:
|
||||
torch_args = {}
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
@@ -83,6 +96,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_processor.process_file(ckpt)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {ckpt}: {e}")
|
||||
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
@@ -43,3 +43,13 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
|
||||
comfy_api/torch_helpers/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .torch_compile import set_torch_compile_wrapper

__all__ = [
    "set_torch_compile_wrapper",
]
comfy_api/torch_helpers/torch_compile.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from __future__ import annotations
import torch

import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.patcher_extension import WrapperExecutor


COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that will refer to the compiled_diffusion_model.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        try:
            orig_modules = {}
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            for key, value in orig_modules.items():
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper


def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: list[str]=["diffusion_model"], *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.

    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {}
    for key in keys:
        compiled_modules[key] = torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
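A short usage sketch for the helper above. The surrounding context (how `model` is obtained) is assumed, not part of this file; `model` stands for a `comfy.model_patcher.ModelPatcher`, e.g. the MODEL output of a loader node.

```python
from comfy_api.torch_helpers import set_torch_compile_wrapper

set_torch_compile_wrapper(
    model,
    backend="inductor",        # any torch.compile backend string
    keys=["diffusion_model"],  # compile the whole diffusion model; pass sub-module keys to narrow it
)
# The compiled modules are swapped in only while apply_model runs, then restored afterwards.
```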
@@ -18,6 +18,8 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
python run main.py --comfy-api-base https://stagingapi.comfy.org
```

To authenticate to staging, please log in and then ask one of the Comfy Org team to whitelist you for access to staging.

API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.

### Redocly Instructions

@@ -28,7 +30,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo

Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.

```bash
# Download the OpenAPI file from prod server.
# Download the OpenAPI file from staging server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi

# Filter out unneeded API definitions.
@@ -39,3 +41,25 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel

```

# Merging to Master

Before merging to comfyanonymous/ComfyUI master, follow these steps:

1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
1. Make sure the ComfyUI API is deployed to prod with your changes.
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.

```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi

# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components

# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel

```
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
import mimetypes
|
||||
from typing import Optional, Union
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
@@ -14,6 +16,7 @@ from comfy_api_nodes.apis.client import (
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
import numpy as np
|
||||
@@ -59,7 +62,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
return s
|
||||
|
||||
|
||||
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||
def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
@@ -93,6 +98,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
elif image_url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {image_url}", node_id
|
||||
)
|
||||
img_response = requests.get(image_url, timeout=timeout)
|
||||
if img_response.status_code != 200:
|
||||
raise ValueError("Failed to download the image")
|
||||
@@ -206,6 +215,7 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
|
||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||
@@ -310,11 +320,27 @@ def tensor_to_data_uri(
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: str,
|
||||
auth_token: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single file to ComfyUI API and returns its download URL.
|
||||
@@ -323,7 +349,7 @@ def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO object containing the file data.
|
||||
filename: The filename of the file.
|
||||
upload_mime_type: MIME type of the file.
|
||||
auth_token: Optional authentication token.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
@@ -337,7 +363,7 @@ def upload_file_to_comfyapi(
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = operation.execute()
|
||||
@@ -349,9 +375,33 @@ def upload_file_to_comfyapi(
|
||||
return response.download_url
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: VideoInput,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = io.BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_token: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
@@ -362,7 +412,7 @@ def upload_video_to_comfyapi(
|
||||
|
||||
Args:
|
||||
video: VideoInput object (Comfy VIDEO type).
|
||||
auth_token: Optional authentication token.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
container: The video container format to use (default: MP4).
|
||||
codec: The video codec to use (default: H264).
|
||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||
@@ -390,7 +440,7 @@ def upload_video_to_comfyapi(
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return upload_file_to_comfyapi(
|
||||
video_bytes_io, filename, upload_mime_type, auth_token
|
||||
video_bytes_io, filename, upload_mime_type, auth_kwargs
|
||||
)
|
||||
|
||||
|
||||
@@ -453,7 +503,7 @@ def audio_ndarray_to_bytesio(
|
||||
|
||||
def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_token: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
@@ -465,7 +515,7 @@ def upload_audio_to_comfyapi(
|
||||
|
||||
Args:
|
||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||
auth_token: Optional authentication token.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded audio file.
|
||||
@@ -477,11 +527,28 @@ def upload_audio_to_comfyapi(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_token)
|
||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def upload_images_to_comfyapi(
|
||||
image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None
|
||||
image: torch.Tensor,
|
||||
max_images=8,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
@@ -490,7 +557,7 @@ def upload_images_to_comfyapi(
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
max_images: Maximum number of images to upload.
|
||||
auth_token: Optional authentication token.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
@@ -521,7 +588,7 @@ def upload_images_to_comfyapi(
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response = operation.execute()
|
||||
|
||||
@@ -546,17 +613,24 @@ def upload_images_to_comfyapi(
|
||||
return download_urls
|
||||
|
||||
|
||||
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
||||
upscale_method="nearest-exact", crop="disabled",
|
||||
allow_gradient=True, add_channel_dim=False):
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""
|
||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||
"""
|
||||
_, H, W, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1,1)
|
||||
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
|
||||
mask = mask.movedim(1,-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(
|
||||
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
||||
)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
@@ -564,12 +638,41 @@ def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
||||
return mask
|
||||
|
||||
|
||||
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
|
||||
if not string:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(
|
||||
image1: torch.Tensor, image2: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
|
||||
File diff suppressed because it is too large.
@@ -108,6 +108,24 @@ class BFLFluxProGenerateRequest(BaseModel):
# )


class BFLFluxKontextProGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for what you want to edit.')
    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
    seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
    guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
    steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
    safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
        2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 2 being least strict. Defaults to 2.'
    )
    output_format: Optional[BFLOutputFormat] = Field(
        BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
    )
    aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
    prompt_upsampling: Optional[bool] = Field(
        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
    )

class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
|
||||
@@ -20,7 +20,8 @@ Usage Examples:
|
||||
# 1. Create the API client
|
||||
api_client = ApiClient(
|
||||
base_url="https://api.example.com",
|
||||
api_key="your_api_key_here",
|
||||
auth_token="your_auth_token_here",
|
||||
comfy_api_key="your_comfy_api_key_here",
|
||||
timeout=30.0,
|
||||
verify_ssl=True
|
||||
)
|
||||
@@ -93,15 +94,19 @@ from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
import io
|
||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
|
||||
import socket
|
||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
||||
from enum import Enum
|
||||
import json
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid # For generating unique operation IDs
|
||||
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
from comfy import utils
|
||||
from . import request_logger
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
@@ -110,6 +115,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response
|
||||
PROGRESS_BAR_MAX = 100
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
pass
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
pass
|
||||
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Base class for empty request bodies.
|
||||
For GET requests, fields will be sent as query parameters."""
|
||||
@@ -119,7 +139,7 @@ class EmptyRequest(BaseModel):
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: str | None = Field(
|
||||
content_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
@@ -140,20 +160,36 @@ class HttpMethod(str, Enum):
|
||||
|
||||
class ApiClient:
|
||||
"""
|
||||
Client for making HTTP requests to an API with authentication and error handling.
|
||||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: Optional[str] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
timeout: float = 3600.0,
|
||||
verify_ssl: bool = True,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
||||
# 500, 502, 503, 504 (Server Errors)
|
||||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
||||
|
||||
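As a sketch of how these retry knobs are meant to interact, the delay schedule below follows a standard exponential-backoff pattern consistent with the defaults; it is an illustration, not code copied from the retry loop itself.

```python
# Illustrative schedule only: delay before each retry, given the defaults above.
def backoff_delays(max_retries=3, retry_delay=1.0, retry_backoff_factor=2.0):
    return [retry_delay * (retry_backoff_factor ** attempt) for attempt in range(max_retries)]

print(backoff_delays())  # [1.0, 2.0, 4.0] seconds before retries 1, 2 and 3
```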
def _generate_operation_id(self, path: str) -> str:
|
||||
"""Generates a unique operation ID for logging."""
|
||||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _create_json_payload_args(
|
||||
self,
|
||||
@@ -201,11 +237,63 @@ class ApiClient:
|
||||
"""Get headers for API requests, including authentication if available"""
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
elif self.comfy_api_key:
|
||||
headers["X-API-KEY"] = self.comfy_api_key
|
||||
|
||||
return headers
|
||||
|
||||
def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
||||
"""
|
||||
Check connectivity to determine if network issues are local or server-related.
|
||||
|
||||
Args:
|
||||
target_url: URL to check connectivity to
|
||||
|
||||
Returns:
|
||||
Dictionary with connectivity status details
|
||||
"""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
"is_local_issue": False,
|
||||
"is_api_issue": False
|
||||
}
|
||||
|
||||
# First check basic internet connectivity using a reliable external site
|
||||
try:
|
||||
# Use a reliable external domain for checking basic connectivity
|
||||
check_response = requests.get("https://www.google.com",
|
||||
timeout=5.0,
|
||||
verify=self.verify_ssl)
|
||||
if check_response.status_code < 500:
|
||||
results["internet_accessible"] = True
|
||||
except (requests.RequestException, socket.error):
|
||||
results["internet_accessible"] = False
|
||||
results["is_local_issue"] = True
|
||||
return results
|
||||
|
||||
# Now check API server connectivity
|
||||
try:
|
||||
# Extract domain from the target URL to do a simpler health check
|
||||
parsed_url = urlparse(target_url)
|
||||
api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
# Try to reach the API domain
|
||||
api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
|
||||
if api_response.status_code < 500:
|
||||
results["api_accessible"] = True
|
||||
else:
|
||||
results["api_accessible"] = False
|
||||
results["is_api_issue"] = True
|
||||
except requests.RequestException:
|
||||
results["api_accessible"] = False
|
||||
# If we can reach the internet but not the API, it's an API issue
|
||||
results["is_api_issue"] = True
|
||||
|
||||
return results
|
||||
|
||||
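A small sketch of how a caller can interpret the dictionary returned by _check_connectivity(); the mapping to the exception types mirrors what request() does below once retries are exhausted (sketch only, not part of the commit):

def classify_connectivity(results: dict):
    # Mirrors the decision request() makes after exhausting retries.
    if results.get("is_local_issue"):
        return LocalNetworkError   # the local internet connection itself is down
    if results.get("is_api_issue"):
        return ApiServerError      # internet works but the API host does not respond
    return None                    # no network-level problem detected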
def request(
|
||||
self,
|
||||
method: str,
|
||||
@@ -216,9 +304,10 @@ class ApiClient:
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable = None,
|
||||
retry_count: int = 0, # Used internally for tracking retries
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an HTTP request to the API
|
||||
Make an HTTP request to the API with automatic retries for transient errors.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
@@ -228,15 +317,20 @@ class ApiClient:
|
||||
files: Files to upload
|
||||
headers: Additional headers
|
||||
content_type: Content type of the request. Defaults to application/json.
|
||||
retry_count: Internal parameter for tracking retries, do not set manually
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
requests.RequestException: If the request fails
|
||||
LocalNetworkError: If local network connectivity issues are detected
|
||||
ApiServerError: If the API server is unreachable but internet is working
|
||||
Exception: For other request failures
|
||||
"""
|
||||
url = urljoin(self.base_url, path)
|
||||
self.check_auth_token(self.api_key)
|
||||
# Use urljoin but ensure path is relative to avoid absolute path behavior
|
||||
relative_path = path.lstrip('/')
|
||||
url = urljoin(self.base_url, relative_path)
|
||||
self.check_auth(self.auth_token, self.comfy_api_key)
|
||||
# Combine default headers with any provided headers
|
||||
request_headers = self.get_headers()
|
||||
if headers:
|
||||
@@ -260,6 +354,16 @@ class ApiClient:
|
||||
else:
|
||||
payload_args = self._create_json_payload_args(data, request_headers)
|
||||
|
||||
operation_id = self._generate_operation_id(path)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=request_headers,
|
||||
request_params=params,
|
||||
request_data=data if content_type == "application/json" else "[form-data or other]"
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
method=method,
|
||||
@@ -270,87 +374,365 @@ class ApiClient:
|
||||
**payload_args,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
if (response.status_code in self.retry_status_codes and
|
||||
retry_count < self.max_retries):
|
||||
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
|
||||
logging.warning(
|
||||
f"Request failed with status {response.status_code}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
|
||||
)
|
||||
|
||||
time.sleep(delay)
|
||||
return self.request(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
# Raise exception for error status codes
|
||||
response.raise_for_status()
|
||||
except requests.ConnectionError:
|
||||
raise Exception(
|
||||
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
|
||||
|
||||
# Log successful response
|
||||
response_content_to_log = response.content
|
||||
try:
|
||||
# Attempt to parse JSON for prettier logging, fallback to raw content
|
||||
response_content_to_log = response.json()
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep as bytes/str if not JSON
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method, # Pass request details again for context in log
|
||||
request_url=url,
|
||||
response_status_code=response.status_code,
|
||||
response_headers=dict(response.headers),
|
||||
response_content=response_content_to_log
|
||||
)
|
||||
|
||||
except requests.Timeout:
|
||||
raise Exception(
|
||||
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
|
||||
except requests.ConnectionError as e:
|
||||
error_message = f"ConnectionError: {str(e)}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
error_message=error_message
|
||||
)
|
||||
# Only perform connectivity check if we've exhausted all retries
|
||||
if retry_count >= self.max_retries:
|
||||
# Check connectivity to determine if it's a local or API issue
|
||||
connectivity = self._check_connectivity(self.base_url)
|
||||
|
||||
if connectivity["is_local_issue"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
elif connectivity["is_api_issue"]:
|
||||
raise ApiServerError(
|
||||
f"The API server at {self.base_url} is currently unreachable. "
|
||||
f"The service may be experiencing issues. Please try again later."
|
||||
) from e
|
||||
|
||||
# If we haven't exhausted retries yet, retry the request
|
||||
if retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
f"Connection error: {str(e)}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
return self.request(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
# If we've exhausted retries and didn't identify the specific issue,
|
||||
# raise a generic exception
|
||||
final_error_message = (
|
||||
f"Unable to connect to the API server after {self.max_retries} attempts. "
|
||||
f"Please check your internet connection or try again later."
|
||||
)
|
||||
request_logger.log_request_response( # Log final failure
|
||||
operation_id=operation_id,
|
||||
request_method=method, request_url=url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise Exception(final_error_message) from e
|
||||
|
||||
except requests.Timeout as e:
|
||||
error_message = f"Timeout: {str(e)}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method, request_url=url,
|
||||
error_message=error_message
|
||||
)
|
||||
# Retry timeouts if we haven't exhausted retries
|
||||
if retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
f"Request timed out. "
|
||||
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
return self.request(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
final_error_message = (
|
||||
f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
|
||||
f"The server might be experiencing high load or the operation is taking longer than expected."
|
||||
)
|
||||
request_logger.log_request_response( # Log final failure
|
||||
operation_id=operation_id,
|
||||
request_method=method, request_url=url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise Exception(final_error_message) from e
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if hasattr(e, "response") else None
|
||||
error_message = f"HTTP Error: {str(e)}"
|
||||
original_error_message = f"HTTP Error: {str(e)}"
|
||||
error_content_for_log = None
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
error_content_for_log = e.response.content
|
||||
try:
|
||||
error_content_for_log = e.response.json()
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
# Try to extract detailed error message from JSON response for user display
|
||||
# but log the full error content.
|
||||
user_display_error_message = original_error_message
|
||||
|
||||
# Try to extract detailed error message from JSON response
|
||||
try:
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
error_json = e.response.json()
|
||||
if "error" in error_json and "message" in error_json["error"]:
|
||||
error_message = f"API Error: {error_json['error']['message']}"
|
||||
user_display_error_message = f"API Error: {error_json['error']['message']}"
|
||||
if "type" in error_json["error"]:
|
||||
error_message += f" (Type: {error_json['error']['type']})"
|
||||
user_display_error_message += f" (Type: {error_json['error']['type']})"
|
||||
elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
|
||||
user_display_error_message = f"API Error: {json.dumps(error_json)}"
|
||||
else: # Non-dict JSON error
|
||||
user_display_error_message = f"API Error: {str(error_json)}"
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, use the raw content if it's not too long, or a summary
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
raw_content = e.response.content.decode(errors='ignore')
|
||||
if len(raw_content) < 200: # Arbitrary limit for display
|
||||
user_display_error_message = f"API Error (raw): {raw_content}"
|
||||
else:
|
||||
error_message = f"API Error: {error_json}"
|
||||
except Exception as json_error:
|
||||
# If we can't parse the JSON, fall back to the original error message
|
||||
logging.debug(
|
||||
f"[DEBUG] Failed to parse error response: {str(json_error)}"
|
||||
user_display_error_message = f"API Error (raw, status {status_code})"
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method, request_url=url,
|
||||
response_status_code=status_code,
|
||||
response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
|
||||
response_content=error_content_for_log,
|
||||
error_message=original_error_message # Log the original exception string as error
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
|
||||
if hasattr(e, "response") and e.response is not None and e.response.content:
|
||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||
|
||||
# Retry if the status code is in our retry list and we haven't exhausted retries
|
||||
if (status_code in self.retry_status_codes and
|
||||
retry_count < self.max_retries):
|
||||
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
f"HTTP error {status_code}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
return self.request(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||
# Specific error messages for common status codes for user display
|
||||
if status_code == 401:
|
||||
error_message = "Unauthorized: Please login first to use this node."
|
||||
if status_code == 402:
|
||||
error_message = "Payment Required: Please add credits to your account to use this node."
|
||||
if status_code == 409:
|
||||
error_message = "There is a problem with your account. Please contact support@comfy.org. "
|
||||
if status_code == 429:
|
||||
error_message = "Rate Limit Exceeded: Please try again later."
|
||||
raise Exception(error_message)
|
||||
user_display_error_message = "Unauthorized: Please login first to use this node."
|
||||
elif status_code == 402:
|
||||
user_display_error_message = "Payment Required: Please add credits to your account to use this node."
|
||||
elif status_code == 409:
|
||||
user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
|
||||
elif status_code == 429:
|
||||
user_display_error_message = "Rate Limit Exceeded: Please try again later."
|
||||
# else, user_display_error_message remains as parsed from response or original HTTPError string
|
||||
|
||||
raise Exception(user_display_error_message) # Raise with the user-friendly message
|
||||
|
||||
# Parse and return JSON response
|
||||
if response.content:
|
||||
return response.json()
|
||||
return {}
|
||||
|
||||
def check_auth_token(self, auth_token):
|
||||
"""Verify that an auth token is present."""
|
||||
if auth_token is None:
|
||||
def check_auth(self, auth_token, comfy_api_key):
|
||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||||
if auth_token is None and comfy_api_key is None:
|
||||
raise Exception("Unauthorized: Please login first to use this node.")
|
||||
return auth_token
|
||||
return auth_token or comfy_api_key
|
||||
|
||||
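A hedged sketch of the resulting header precedence: auth_token claims the Authorization header, and comfy_api_key is only sent (as X-API-KEY) when no auth_token is present, matching get_headers() above. The base URL and token values below are made up for illustration:

# Hypothetical values; only the precedence is the point here.
client_a = ApiClient("https://api.comfy.org/", auth_token="token123")
client_b = ApiClient("https://api.comfy.org/", comfy_api_key="key456")

assert client_a.get_headers()["Authorization"] == "Bearer token123"
assert client_b.get_headers()["X-API-KEY"] == "key456"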
@staticmethod
|
||||
def upload_file(
|
||||
upload_url: str,
|
||||
file: io.BytesIO | str,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
):
|
||||
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
|
||||
"""Upload a file to the API with retry logic.
|
||||
|
||||
Args:
|
||||
upload_url: The URL to upload to
|
||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||
mime_type: Optional mime type to set for the upload
|
||||
content_type: Optional mime type to set for the upload
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
retry_backoff_factor: Multiplier for the delay after each retry
|
||||
"""
|
||||
headers = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
# Prepare the file data
|
||||
if isinstance(file, io.BytesIO):
|
||||
file.seek(0) # Ensure we're at the start of the file
|
||||
data = file.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
else:
|
||||
raise ValueError("File must be either a BytesIO object or a file path string")
|
||||
|
||||
# Try the upload with retries
|
||||
last_exception = None
|
||||
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
|
||||
|
||||
# Log initial attempt (without full file data for brevity)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers,
|
||||
request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
|
||||
)
|
||||
|
||||
for retry_attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.put(upload_url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url, # For context
|
||||
response_status_code=response.status_code,
|
||||
response_headers=dict(response.headers),
|
||||
response_content="File uploaded successfully." # Or response.text if available
|
||||
)
|
||||
return response
|
||||
|
||||
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
|
||||
last_exception = e
|
||||
error_message_for_log = f"{type(e).__name__}: {str(e)}"
|
||||
response_content_for_log = None
|
||||
status_code_for_log = None
|
||||
headers_for_log = None
|
||||
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
status_code_for_log = e.response.status_code
|
||||
headers_for_log = dict(e.response.headers)
|
||||
try:
|
||||
response_content_for_log = e.response.json()
|
||||
except json.JSONDecodeError:
|
||||
response_content_for_log = e.response.content
|
||||
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
response_status_code=status_code_for_log,
|
||||
response_headers=headers_for_log,
|
||||
response_content=response_content_for_log,
|
||||
error_message=error_message_for_log
|
||||
)
|
||||
|
||||
if retry_attempt < max_retries:
|
||||
delay = retry_delay * (retry_backoff_factor ** retry_attempt)
|
||||
logging.warning(
|
||||
f"File upload failed: {str(e)}. "
|
||||
f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break # Max retries reached
|
||||
|
||||
# If we've exhausted all retries, determine the final error type and raise
|
||||
final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
|
||||
try:
|
||||
# Check basic internet connectivity
|
||||
check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
|
||||
if check_response.status_code >= 500: # Google itself has an issue (rare)
|
||||
final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
|
||||
f"(status {check_response.status_code}). Original error: {str(last_exception)}")
|
||||
# Not raising LocalNetworkError here as Google itself might be down.
|
||||
# If Google is reachable, the issue is likely with the upload server or a more specific local problem
|
||||
# not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
|
||||
# The original last_exception is probably most relevant.
|
||||
|
||||
except (requests.RequestException, socket.error) as conn_check_exc:
|
||||
# Could not reach Google, likely a local network issue
|
||||
final_error_message = (f"Failed to upload file due to network connectivity issues "
|
||||
f"(cannot reach Google: {str(conn_check_exc)}). "
|
||||
f"Original upload error: {str(last_exception)}")
|
||||
request_logger.log_request_response( # Log final failure reason
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise LocalNetworkError(final_error_message) from last_exception
|
||||
|
||||
request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
|
||||
operation_id=operation_id,
|
||||
request_method="PUT", request_url=upload_url,
|
||||
error_message=final_error_message
|
||||
)
|
||||
raise Exception(final_error_message) from last_exception
|
||||
|
||||
|
||||
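A hedged usage sketch for upload_file: pushing an in-memory buffer to a presigned URL with the retry defaults. The URL and payload are invented; in practice the upload URL comes from a prior API response.

import io

upload_url = "https://storage.example.com/bucket/object?signature=abc"  # hypothetical
buffer = io.BytesIO(b"fake image bytes")

response = ApiClient.upload_file(
    upload_url,
    buffer,
    content_type="image/png",
    max_retries=3,
    retry_delay=1.0,
)
print(response.status_code)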
class ApiEndpoint(Generic[T, R]):
|
||||
@@ -392,10 +774,15 @@ class SynchronousOperation(Generic[T, R]):
|
||||
files: Optional[Dict[str, Any]] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
timeout: float = 604800.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
@@ -403,21 +790,33 @@ class SynchronousOperation(Generic[T, R]):
|
||||
self.error = None
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.files = files
|
||||
self.content_type = content_type
|
||||
self.multipart_parser = multipart_parser
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
|
||||
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
"""Execute the API operation using the provided client or create one"""
|
||||
"""Execute the API operation using the provided client or create one with retry support"""
|
||||
try:
|
||||
# Create client if not provided
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
api_key=self.auth_token,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
timeout=self.timeout,
|
||||
verify_ssl=self.verify_ssl,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
|
||||
# Convert request model to dict, but use None for EmptyRequest
|
||||
@@ -431,11 +830,6 @@ class SynchronousOperation(Generic[T, R]):
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
if request_dict:
|
||||
for key, value in request_dict.items():
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
# Debug log for request
|
||||
logging.debug(
|
||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||
@@ -443,7 +837,7 @@ class SynchronousOperation(Generic[T, R]):
|
||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||
|
||||
# Make the request
|
||||
# Make the request with built-in retry
|
||||
resp = client.request(
|
||||
method=self.endpoint.method.value,
|
||||
path=self.endpoint.path,
|
||||
@@ -464,8 +858,18 @@ class SynchronousOperation(Generic[T, R]):
|
||||
# Parse and return the response
|
||||
return self._parse_response(resp)
|
||||
|
||||
except LocalNetworkError as e:
|
||||
# Propagate specific network error types
|
||||
logging.error(f"[ERROR] Local network error: {str(e)}")
|
||||
raise
|
||||
|
||||
except ApiServerError as e:
|
||||
# Propagate API server errors
|
||||
logging.error(f"[ERROR] API server error: {str(e)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"[DEBUG] API Exception: {str(e)}")
|
||||
logging.error(f"[ERROR] API Exception: {str(e)}")
|
||||
raise Exception(str(e))
|
||||
|
||||
def _parse_response(self, resp):
|
||||
@@ -499,22 +903,42 @@ class PollingOperation(Generic[T, R]):
|
||||
failed_statuses: list,
|
||||
status_extractor: Callable[[R], str],
|
||||
progress_extractor: Callable[[R], float] = None,
|
||||
result_url_extractor: Callable[[R], str] = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||
max_retries: int = 3, # Max retries per individual API call
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
estimated_duration: Optional[float] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.poll_interval = poll_interval
|
||||
self.max_poll_attempts = max_poll_attempts
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
self.estimated_duration = estimated_duration
|
||||
|
||||
# Polling configuration
|
||||
self.status_extractor = status_extractor or (
|
||||
lambda x: getattr(x, "status", None)
|
||||
)
|
||||
self.progress_extractor = progress_extractor
|
||||
self.result_url_extractor = result_url_extractor
|
||||
self.node_id = node_id
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
|
||||
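As a quick sanity check on the defaults above, the worst-case polling window is max_poll_attempts * poll_interval, which matches the "10 minutes" noted in the parameter comment (sketch only):

poll_interval = 5.0        # seconds between polls (default)
max_poll_attempts = 120    # default cap on polling attempts

worst_case_seconds = poll_interval * max_poll_attempts
assert worst_case_seconds == 600.0   # i.e. 10 minutes before polling gives up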
@@ -528,12 +952,48 @@ class PollingOperation(Generic[T, R]):
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
api_key=self.auth_token,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
return self._poll_until_complete(client)
|
||||
except LocalNetworkError as e:
|
||||
# Provide clear message for local network issues
|
||||
raise Exception(
|
||||
f"Polling failed due to local network issues. Please check your internet connection. "
|
||||
f"Details: {str(e)}"
|
||||
) from e
|
||||
except ApiServerError as e:
|
||||
# Provide clear message for API server issues
|
||||
raise Exception(
|
||||
f"Polling failed due to API server issues. The service may be experiencing problems. "
|
||||
f"Please try again later. Details: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise Exception(f"Error during polling: {str(e)}")
|
||||
|
||||
def _display_text_on_node(self, text: str):
|
||||
"""Sends text to the client which will be displayed on the node in the UI"""
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||||
|
||||
def _display_time_progress_on_node(self, time_completed: int):
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
if self.estimated_duration is not None:
|
||||
estimated_time_remaining = max(
|
||||
0, int(self.estimated_duration) - int(time_completed)
|
||||
)
|
||||
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
|
||||
else:
|
||||
message = f"Task in progress: {time_completed:.0f}s"
|
||||
self._display_text_on_node(message)
|
||||
|
||||
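Illustrative only: the message _display_time_progress_on_node produces for a task that has run 45 seconds against an estimated_duration of 120 seconds (values are made up):

estimated_duration = 120
time_completed = 45
remaining = max(0, int(estimated_duration) - int(time_completed))
message = f"Task in progress: {time_completed:.0f}s (~{remaining:.0f}s remaining)"
print(message)  # Task in progress: 45s (~75s remaining)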
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
"""Check task status using the status extractor function"""
|
||||
try:
|
||||
@@ -550,10 +1010,13 @@ class PollingOperation(Generic[T, R]):
|
||||
def _poll_until_complete(self, client: ApiClient) -> R:
|
||||
"""Poll until the task is complete"""
|
||||
poll_count = 0
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
||||
|
||||
if self.progress_extractor:
|
||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||
|
||||
while True:
|
||||
while poll_count < self.max_poll_attempts:
|
||||
try:
|
||||
poll_count += 1
|
||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||
@@ -580,8 +1043,12 @@ class PollingOperation(Generic[T, R]):
|
||||
data=request_dict,
|
||||
)
|
||||
|
||||
# Successfully got a response, reset consecutive error count
|
||||
consecutive_errors = 0
|
||||
|
||||
# Parse response
|
||||
response_obj = self.poll_endpoint.response_model.model_validate(resp)
|
||||
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||
@@ -593,7 +1060,15 @@ class PollingOperation(Generic[T, R]):
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
logging.debug("[DEBUG] Task completed successfully")
|
||||
message = "Task completed successfully"
|
||||
if self.result_url_extractor:
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
else:
|
||||
message = "Task completed successfully!"
|
||||
logging.debug(f"[DEBUG] {message}")
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
@@ -609,8 +1084,43 @@ class PollingOperation(Generic[T, R]):
|
||||
logging.debug(
|
||||
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
|
||||
)
|
||||
for i in range(int(self.poll_interval)):
|
||||
time_completed = (poll_count * self.poll_interval) + i
|
||||
self._display_time_progress_on_node(time_completed)
|
||||
time.sleep(1)
|
||||
|
||||
except (LocalNetworkError, ApiServerError) as e:
|
||||
# For network-related errors, increment error count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Log the error but continue polling
|
||||
logging.warning(
|
||||
f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
except Exception as e:
|
||||
# For other errors, increment count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||
raise Exception(f"Error while polling: {str(e)}")
|
||||
logging.warning(
|
||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
# If we've exhausted all polling attempts
|
||||
raise Exception(
|
||||
f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
|
||||
f"The operation may still be running on the server but is taking longer than expected."
|
||||
)
|
||||
|
||||
@@ -81,7 +81,6 @@ class RecraftStyle:
|
||||
|
||||
class RecraftIO:
|
||||
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||
SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
|
||||
COLOR = "RECRAFT_COLOR"
|
||||
CONTROLS = "RECRAFT_CONTROLS"
|
||||
|
||||
|
||||
comfy_api_nodes/apis/request_logger.py (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_log_directory():
|
||||
"""
|
||||
Ensures the API log directory exists within ComfyUI's temp directory
|
||||
and returns its path.
|
||||
"""
|
||||
base_temp_dir = folder_paths.get_temp_directory()
|
||||
log_dir = os.path.join(base_temp_dir, "api_logs")
|
||||
try:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating API log directory {log_dir}: {e}")
|
||||
# Fallback to base temp directory if sub-directory creation fails
|
||||
return base_temp_dir
|
||||
return log_dir
|
||||
|
||||
def _format_data_for_logging(data):
|
||||
"""Helper to format data (dict, str, bytes) for logging."""
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
return data.decode('utf-8') # Try to decode as text
|
||||
except UnicodeDecodeError:
|
||||
return f"[Binary data of length {len(data)} bytes]"
|
||||
elif isinstance(data, (dict, list)):
|
||||
try:
|
||||
return json.dumps(data, indent=2, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(data) # Fallback for non-serializable objects
|
||||
return str(data)
|
||||
|
||||
def log_request_response(
|
||||
operation_id: str,
|
||||
request_method: str,
|
||||
request_url: str,
|
||||
request_headers: dict | None = None,
|
||||
request_params: dict | None = None,
|
||||
request_data: any = None,
|
||||
response_status_code: int | None = None,
|
||||
response_headers: dict | None = None,
|
||||
response_content: any = None,
|
||||
error_message: str | None = None
|
||||
):
|
||||
"""
|
||||
Logs API request and response details to a file in the temp/api_logs directory.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
|
||||
filepath = os.path.join(log_dir, filename)
|
||||
|
||||
log_content = []
|
||||
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug(f"API log saved to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing API log to {filepath}: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Example usage (for testing the logger directly)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# Mock folder_paths for direct execution if not running within ComfyUI full context
|
||||
if not hasattr(folder_paths, 'get_temp_directory'):
|
||||
class MockFolderPaths:
|
||||
def get_temp_directory(self):
|
||||
# Create a local temp dir for testing if needed
|
||||
p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
|
||||
os.makedirs(p, exist_ok=True)
|
||||
return p
|
||||
folder_paths = MockFolderPaths()
|
||||
|
||||
log_request_response(
|
||||
operation_id="test_operation_get",
|
||||
request_method="GET",
|
||||
request_url="https://api.example.com/test",
|
||||
request_headers={"Authorization": "Bearer testtoken"},
|
||||
request_params={"param1": "value1"},
|
||||
response_status_code=200,
|
||||
response_content={"message": "Success!"}
|
||||
)
|
||||
log_request_response(
|
||||
operation_id="test_operation_post_error",
|
||||
request_method="POST",
|
||||
request_url="https://api.example.com/submit",
|
||||
request_data={"key": "value", "nested": {"num": 123}},
|
||||
error_message="Connection timed out"
|
||||
)
|
||||
log_request_response(
|
||||
operation_id="test_binary_response",
|
||||
request_method="GET",
|
||||
request_url="https://api.example.com/image.png",
|
||||
response_status_code=200,
|
||||
response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
|
||||
)
|
||||
comfy_api_nodes/apis/rodin_api.py (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Rodin3DGenerateRequest(BaseModel):
|
||||
seed: int = Field(..., description="Random seed for the generation.")
|
||||
tier: str = Field(..., description="Tier of generation.")
|
||||
material: str = Field(..., description="The material type.")
|
||||
quality: str = Field(..., description="The generation quality of the mesh.")
|
||||
mesh_mode: str = Field(..., description="Controls the type of faces in the generated model.")
|
||||
|
||||
class GenerateJobsData(BaseModel):
|
||||
uuids: List[str] = Field(..., description="List of job UUIDs.")
|
||||
subscription_key: str = Field(..., description="subscription key")
|
||||
|
||||
class Rodin3DGenerateResponse(BaseModel):
|
||||
message: Optional[str] = Field(None, description="Return message.")
|
||||
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: Optional[str] = Field(None, description="Submit Time")
|
||||
uuid: Optional[str] = Field(None, description="Task UUID.")
|
||||
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
"""
|
||||
Status for jobs
|
||||
"""
|
||||
Done = "Done"
|
||||
Failed = "Failed"
|
||||
Generating = "Generating"
|
||||
Waiting = "Waiting"
|
||||
|
||||
class Rodin3DCheckStatusRequest(BaseModel):
|
||||
subscription_key: str = Field(..., description="Subscription key returned by the generate endpoint.")
|
||||
|
||||
class JobItem(BaseModel):
|
||||
uuid: str = Field(..., description="uuid")
|
||||
status: JobStatus = Field(..., description="Current status of the job.")
|
||||
|
||||
class Rodin3DCheckStatusResponse(BaseModel):
|
||||
jobs: List[JobItem] = Field(..., description="Job status List")
|
||||
|
||||
class Rodin3DDownloadRequest(BaseModel):
|
||||
task_uuid: str = Field(..., description="Task UUID.")
|
||||
|
||||
class RodinResourceItem(BaseModel):
|
||||
url: str = Field(..., description="Download Url")
|
||||
name: str = Field(..., description="File name with ext")
|
||||
|
||||
class Rodin3DDownloadResponse(BaseModel):
|
||||
list: List[RodinResourceItem] = Field(..., description="Source List")
|
||||
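A small, hypothetical sketch of validating a Rodin status payload with the models above; the JSON shape is assumed from the field definitions, not taken from actual API responses:

# Hypothetical payload; field names follow Rodin3DCheckStatusResponse.
payload = {
    "jobs": [
        {"uuid": "job-1", "status": "Generating"},
        {"uuid": "job-2", "status": "Done"},
    ]
}

status = Rodin3DCheckStatusResponse.model_validate(payload)
all_done = all(job.status == JobStatus.Done for job in status.jobs)
print(all_done)  # False until every job reports "Done"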
comfy_api_nodes/apis/tripo_api.py (new file, 275 lines)
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api_nodes.apis import (
|
||||
TripoModelVersion,
|
||||
TripoTextureQuality,
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoStyle(str, Enum):
|
||||
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||
ANIMAL_VENOM = "animal:venom"
|
||||
OBJECT_CLAY = "object:clay"
|
||||
OBJECT_STEAMPUNK = "object:steampunk"
|
||||
OBJECT_CHRISTMAS = "object:christmas"
|
||||
OBJECT_BARBIE = "object:barbie"
|
||||
GOLD = "gold"
|
||||
ANCIENT_BRONZE = "ancient_bronze"
|
||||
NONE = "None"
|
||||
|
||||
class TripoTaskType(str, Enum):
|
||||
TEXT_TO_MODEL = "text_to_model"
|
||||
IMAGE_TO_MODEL = "image_to_model"
|
||||
MULTIVIEW_TO_MODEL = "multiview_to_model"
|
||||
TEXTURE_MODEL = "texture_model"
|
||||
REFINE_MODEL = "refine_model"
|
||||
ANIMATE_PRERIGCHECK = "animate_prerigcheck"
|
||||
ANIMATE_RIG = "animate_rig"
|
||||
ANIMATE_RETARGET = "animate_retarget"
|
||||
STYLIZE_MODEL = "stylize_model"
|
||||
CONVERT_MODEL = "convert_model"
|
||||
|
||||
class TripoTextureAlignment(str, Enum):
|
||||
ORIGINAL_IMAGE = "original_image"
|
||||
GEOMETRY = "geometry"
|
||||
|
||||
class TripoOrientation(str, Enum):
|
||||
ALIGN_IMAGE = "align_image"
|
||||
DEFAULT = "default"
|
||||
|
||||
class TripoOutFormat(str, Enum):
|
||||
GLB = "glb"
|
||||
FBX = "fbx"
|
||||
|
||||
class TripoTopology(str, Enum):
|
||||
BIP = "bip"
|
||||
QUAD = "quad"
|
||||
|
||||
class TripoSpec(str, Enum):
|
||||
MIXAMO = "mixamo"
|
||||
TRIPO = "tripo"
|
||||
|
||||
class TripoAnimation(str, Enum):
|
||||
IDLE = "preset:idle"
|
||||
WALK = "preset:walk"
|
||||
CLIMB = "preset:climb"
|
||||
JUMP = "preset:jump"
|
||||
RUN = "preset:run"
|
||||
SLASH = "preset:slash"
|
||||
SHOOT = "preset:shoot"
|
||||
HURT = "preset:hurt"
|
||||
FALL = "preset:fall"
|
||||
TURN = "preset:turn"
|
||||
|
||||
class TripoStylizeStyle(str, Enum):
|
||||
LEGO = "lego"
|
||||
VOXEL = "voxel"
|
||||
VORONOI = "voronoi"
|
||||
MINECRAFT = "minecraft"
|
||||
|
||||
class TripoConvertFormat(str, Enum):
|
||||
GLTF = "GLTF"
|
||||
USDZ = "USDZ"
|
||||
FBX = "FBX"
|
||||
OBJ = "OBJ"
|
||||
STL = "STL"
|
||||
_3MF = "3MF"
|
||||
|
||||
class TripoTextureFormat(str, Enum):
|
||||
BMP = "BMP"
|
||||
DPX = "DPX"
|
||||
HDR = "HDR"
|
||||
JPEG = "JPEG"
|
||||
OPEN_EXR = "OPEN_EXR"
|
||||
PNG = "PNG"
|
||||
TARGA = "TARGA"
|
||||
TIFF = "TIFF"
|
||||
WEBP = "WEBP"
|
||||
|
||||
class TripoTaskStatus(str, Enum):
|
||||
QUEUED = "queued"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
UNKNOWN = "unknown"
|
||||
BANNED = "banned"
|
||||
EXPIRED = "expired"
|
||||
|
||||
class TripoFileTokenReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
file_token: str
|
||||
|
||||
class TripoUrlReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
url: str
|
||||
|
||||
class TripoObjectStorage(BaseModel):
|
||||
bucket: str
|
||||
key: str
|
||||
|
||||
class TripoObjectReference(BaseModel):
|
||||
type: str
|
||||
object: TripoObjectStorage
|
||||
|
||||
class TripoFileEmptyReference(BaseModel):
|
||||
pass
|
||||
|
||||
class TripoFileReference(RootModel):
|
||||
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
|
||||
|
||||
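A minimal sketch of wrapping one of the union members in the TripoFileReference root model above; the URL is invented, and validating from a raw dict would also work, with pydantic selecting the matching member:

ref = TripoFileReference(root=TripoUrlReference(type="url", url="https://example.com/input.png"))
print(ref.model_dump())  # {'type': 'url', 'url': 'https://example.com/input.png'}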
class TripoGetStsTokenRequest(BaseModel):
|
||||
format: str = Field(..., description='The format of the image')
|
||||
|
||||
class TripoTextToModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
|
||||
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
|
||||
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
|
||||
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
image_seed: Optional[int] = Field(None, description='The seed for the text')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
style: Optional[TripoStyle] = None
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
|
||||
class TripoImageToModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
|
||||
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
|
||||
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
|
||||
class TripoMultiviewToModelRequest(BaseModel):
|
||||
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
|
||||
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
|
||||
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
|
||||
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
|
||||
|
||||
class TripoAnimatePrerigcheckRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
|
||||
class TripoAnimateRigRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
|
||||
|
||||
class TripoAnimateRetargetRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
animation: TripoAnimation = Field(..., description='The animation to apply')
|
||||
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
|
||||
|
||||
class TripoStylizeModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
|
||||
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
block_size: Optional[int] = Field(80, description='The block size for stylization')
|
||||
|
||||
class TripoConvertModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(4096, description='The size of the texture')
|
||||
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
|
||||
|
||||
class TripoTaskRequest(RootModel):
|
||||
root: Union[
|
||||
TripoTextToModelRequest,
|
||||
TripoImageToModelRequest,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoRefineModelRequest,
|
||||
TripoAnimatePrerigcheckRequest,
|
||||
TripoAnimateRigRequest,
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoStylizeModelRequest,
|
||||
TripoConvertModelRequest
|
||||
]
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: Optional[str] = Field(None, description='URL to the model')
|
||||
base_model: Optional[str] = Field(None, description='URL to the base model')
|
||||
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
|
||||
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
|
||||
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
|
||||
|
||||
class TripoTask(BaseModel):
|
||||
task_id: str = Field(..., description='The task ID')
|
||||
type: Optional[str] = Field(None, description='The type of task')
|
||||
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
|
||||
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
|
||||
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
|
||||
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
|
||||
create_time: Optional[int] = Field(None, description='The creation time of the task')
|
||||
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
|
||||
queue_position: Optional[int] = Field(None, description='The position in the queue')
|
||||
|
||||
class TripoTaskResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: TripoTask = Field(..., description='The task data')
|
||||
|
||||
class TripoGeneralResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: Dict[str, str] = Field(..., description='The task ID data')
|
||||
|
||||
class TripoBalanceData(BaseModel):
|
||||
balance: float = Field(..., description='The account balance')
|
||||
frozen: float = Field(..., description='The frozen balance')
|
||||
|
||||
class TripoBalanceResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: TripoBalanceData = Field(..., description='The balance data')
|
||||
|
||||
class TripoErrorResponse(BaseModel):
|
||||
code: int = Field(..., description='The error code')
|
||||
message: str = Field(..., description='The error message')
|
||||
suggestion: str = Field(..., description='The suggestion for fixing the error')
|
||||
@@ -1,5 +1,6 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from typing import Union, Optional
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLStatus,
|
||||
@@ -8,6 +9,7 @@ from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLFluxCannyImageRequest,
|
||||
BFLFluxDepthImageRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
)
|
||||
@@ -30,6 +32,7 @@ import requests
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
@@ -42,14 +45,19 @@ def convert_mask_to_image(mask: torch.Tensor):
|
||||
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation, timeout_bfl_calls=360
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
def _poll_until_generated(polling_url: str, timeout=360):
|
||||
|
||||
def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||
start_time = time.time()
|
||||
@@ -61,11 +69,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
@@ -179,6 +197,8 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -211,7 +231,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
image_prompt_strength=0.1,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if image_prompt is None:
|
||||
@@ -244,12 +264,164 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||
),
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxKontextProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation - specify what and how to edit.",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 3.0,
|
||||
"min": 0.1,
|
||||
"max": 99.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 1,
|
||||
"max": 150,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1234,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"input_image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Optional[torch.Tensor]=None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=self.BFL_PATH,
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxKontextProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
else convert_image_to_base64(input_image)
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
|
||||
|
||||
|
||||
class FluxProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
@@ -319,6 +491,8 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -337,7 +511,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image_prompt = (
|
||||
@@ -361,9 +535,9 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -457,10 +631,11 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -482,7 +657,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
image = convert_image_to_base64(image)
|
||||
@@ -506,9 +681,9 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
seed=seed,
|
||||
image=image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -568,10 +743,11 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -590,14 +766,14 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:,:,:,:3])
|
||||
image = convert_image_to_base64(image[:, :, :, :3])
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -615,9 +791,9 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
image=image,
|
||||
mask=mask,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -702,10 +878,11 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -726,10 +903,10 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
control_image = convert_image_to_base64(control_image[:, :, :, :3])
|
||||
preprocessed_image = None
|
||||
|
||||
# scale canny threshold between 0-500, to match BFL's API
|
||||
@@ -763,9 +940,9 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -830,10 +1007,11 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -852,7 +1030,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
auth_token=None,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
@@ -878,9 +1056,9 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -889,6 +1067,8 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FluxProUltraImageNode": FluxProUltraImageNode,
|
||||
# "FluxProImageNode": FluxProImageNode,
|
||||
"FluxKontextProImageNode": FluxKontextProImageNode,
|
||||
"FluxKontextMaxImageNode": FluxKontextMaxImageNode,
|
||||
"FluxProExpandNode": FluxProExpandNode,
|
||||
"FluxProFillNode": FluxProFillNode,
|
||||
"FluxProCannyNode": FluxProCannyNode,
|
||||
@@ -899,6 +1079,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
|
||||
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
|
||||
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||
"FluxProFillNode": "Flux.1 Fill Image",
|
||||
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||
|
||||
446
comfy_api_nodes/nodes_gemini.py
Normal file
446
comfy_api_nodes/nodes_gemini.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
import torch
|
||||
|
||||
import folder_paths
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from server import PromptServer
|
||||
from comfy_api_nodes.apis import (
|
||||
GeminiContent,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiInlineData,
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
validate_string,
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
"""
|
||||
Gemini Model Names allowed by comfy-api
|
||||
"""
|
||||
|
||||
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
model: GeminiModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, audio, video, files) to generate coherent
|
||||
text responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiModel],
|
||||
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"audio": (
|
||||
IO.AUDIO,
|
||||
{
|
||||
"tooltip": "Optional audio to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"tooltip": "Optional video to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def get_parts_from_response(
|
||||
self, response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
def get_parts_by_type(
|
||||
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in self.get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = self.get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
video_input: Video tensor from ComfyUI.
|
||||
**kwargs: Additional arguments to pass to the conversion function.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded video.
|
||||
"""
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input,
|
||||
container_format=VideoContainer.MP4,
|
||||
codec=VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded audio.
|
||||
"""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = {
|
||||
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
"sample_rate": audio_input["sample_rate"],
|
||||
}
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
def create_text_part(self, text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
audio: Optional[IO.AUDIO] = None,
|
||||
video: Optional[IO.VIDEO] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [self.create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = self.create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(self.create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
response = SynchronousOperation(
|
||||
endpoint=get_gemini_endpoint(model),
|
||||
request=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
# Get result output
|
||||
output_text = self.get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
||||
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
|
||||
|
||||
class GeminiInputFiles(ComfyNodeABC):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
|
||||
This node allows users to include text (.txt) and PDF (.pdf) files as input
|
||||
context for the Gemini model. Files are converted to the appropriate format
|
||||
required by the API and can be chained together to include multiple files
|
||||
in a single request.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
"""
|
||||
For details about the supported file input types, see:
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
input_files = [
|
||||
f
|
||||
for f in os.scandir(input_dir)
|
||||
if f.is_file()
|
||||
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
|
||||
and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
|
||||
]
|
||||
input_files = sorted(input_files, key=lambda x: x.name)
|
||||
input_files = [f.name for f in input_files]
|
||||
return {
|
||||
"required": {
|
||||
"file": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
"options": input_files,
|
||||
"default": input_files[0] if input_files else None,
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"GEMINI_INPUT_FILES": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
|
||||
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
|
||||
FUNCTION = "prepare_files"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
|
||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||
mime_type = (
|
||||
GeminiMimeType.pdf
|
||||
if file_path.endswith(".pdf")
|
||||
else GeminiMimeType.text_plain
|
||||
)
|
||||
# Use base64 string directly, not the data URI
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
import base64
|
||||
base64_str = base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
return GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=mime_type,
|
||||
data=base64_str,
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_files(
|
||||
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
|
||||
) -> tuple[list[GeminiPart]]:
|
||||
"""
|
||||
Loads and formats input files for Gemini API.
|
||||
"""
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
input_file_content = self.create_file_part(file_path)
|
||||
files = [input_file_content] + GEMINI_INPUT_FILES
|
||||
return (files,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
@@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
bytesio_to_image_tensor,
|
||||
resize_mask_to_image,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
@@ -232,11 +233,22 @@ def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V1 model.
|
||||
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -303,7 +315,11 @@ class IdeogramV1(ComfyNodeABC):
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -321,7 +337,8 @@ class IdeogramV1(ComfyNodeABC):
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
@@ -347,7 +364,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
@@ -360,14 +377,13 @@ class IdeogramV1(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V2 model.
|
||||
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -458,7 +474,11 @@ class IdeogramV2(ComfyNodeABC):
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -479,7 +499,8 @@ class IdeogramV2(ComfyNodeABC):
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
@@ -519,7 +540,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
@@ -532,14 +553,12 @@ class IdeogramV2(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously using the Ideogram V3 model.
|
||||
|
||||
Supports both regular image generation from text prompts and image editing with mask.
|
||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -621,7 +640,11 @@ class IdeogramV3(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -641,7 +664,8 @@ class IdeogramV3(ComfyNodeABC):
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
@@ -705,7 +729,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
"mask": mask_binary,
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
@@ -746,7 +770,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
@@ -760,6 +784,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
@@ -774,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar, Any
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
import logging
|
||||
|
||||
@@ -64,6 +65,12 @@ from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
@@ -79,13 +86,20 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
|
||||
PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on"
|
||||
PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH_T2V = 2500
|
||||
MAX_PROMPT_LENGTH_I2V = 500
|
||||
MAX_PROMPT_LENGTH_IMAGE_GEN = 500
|
||||
MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
|
||||
MAX_PROMPT_LENGTH_LIP_SYNC = 120
|
||||
|
||||
AVERAGE_DURATION_T2V = 319
|
||||
AVERAGE_DURATION_I2V = 164
|
||||
AVERAGE_DURATION_LIP_SYNC = 455
|
||||
AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
|
||||
AVERAGE_DURATION_IMAGE_GEN = 32
|
||||
AVERAGE_DURATION_VIDEO_EFFECTS = 320
|
||||
AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@@ -95,7 +109,13 @@ class KlingApiError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R:
|
||||
def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
@@ -108,7 +128,10 @@ def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R
|
||||
if response.data and response.data.task_status
|
||||
else None
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
@@ -184,6 +207,18 @@ def validate_image_result_response(response) -> None:
|
||||
raise KlingApiError(error_msg)
|
||||
|
||||
|
||||
def validate_input_image(image: torch.Tensor) -> None:
|
||||
"""
|
||||
Validates the input image adheres to the expectations of the Kling API:
|
||||
- The image resolution should not be less than 300*300px
|
||||
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
|
||||
|
||||
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
|
||||
"""
|
||||
validate_image_dimensions(image, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
|
||||
|
||||
|
||||
def get_camera_control_input_config(
|
||||
tooltip: str, default: float = 0.0
|
||||
) -> tuple[IO, InputTypeOptions]:
|
||||
@@ -200,7 +235,9 @@ def get_camera_control_input_config(
|
||||
|
||||
|
||||
def get_video_from_response(response) -> KlingVideoResult:
|
||||
"""Returns the first video object from the Kling video generation task result."""
|
||||
"""Returns the first video object from the Kling video generation task result.
|
||||
Will raise an error if the response is not valid.
|
||||
"""
|
||||
video = response.data.task_result.videos[0]
|
||||
logging.info(
|
||||
"Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
|
||||
@@ -208,12 +245,37 @@ def get_video_from_response(response) -> KlingVideoResult:
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Kling video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response and is_valid_video_response(response):
|
||||
return str(get_video_from_response(response).url)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_images_from_response(response) -> list[KlingImageResult]:
|
||||
"""Returns the list of image objects from the Kling image generation task result.
|
||||
Will raise an error if the response is not valid.
|
||||
"""
|
||||
images = response.data.task_result.images
|
||||
logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
|
||||
return images
|
||||
|
||||
|
||||
def get_images_urls_from_response(response) -> Optional[str]:
|
||||
"""Returns the list of image urls from the Kling image generation task result.
|
||||
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
|
||||
"""
|
||||
if response and is_valid_image_response(response):
|
||||
images = get_images_from_response(response)
|
||||
image_urls = [str(image.url) for image in images]
|
||||
return "\n".join(image_urls)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def video_result_to_node_output(
|
||||
video: KlingVideoResult,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
@@ -285,6 +347,7 @@ class KlingCameraControls(KlingNodeBase):
|
||||
RETURN_TYPES = ("CAMERA_CONTROL",)
|
||||
RETURN_NAMES = ("camera_control",)
|
||||
FUNCTION = "main"
|
||||
API_NODE = False # This is just a helper node, it doesn't make an API call
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(
|
||||
@@ -391,22 +454,31 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Text to Video Node"
|
||||
|
||||
def get_response(self, task_id: str, auth_token: str) -> KlingText2VideoResponse:
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingText2VideoResponse:
|
||||
return poll_until_finished(
|
||||
auth_token,
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
@@ -419,7 +491,8 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
model_name: Optional[str] = None,
|
||||
duration: Optional[str] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
if model_name is None:
|
||||
@@ -441,14 +514,16 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = self.get_response(task_id, auth_token)
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@@ -495,7 +570,11 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
||||
@@ -507,7 +586,8 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
@@ -518,7 +598,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
camera_control=camera_control,
|
||||
auth_token=auth_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -530,7 +610,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
return {
|
||||
"required": {
|
||||
"start_frame": model_field_to_node_input(
|
||||
IO.IMAGE, KlingImage2VideoRequest, "image"
|
||||
IO.IMAGE,
|
||||
KlingImage2VideoRequest,
|
||||
"image",
|
||||
tooltip="The reference image used to generate the video.",
|
||||
),
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
|
||||
@@ -574,22 +657,31 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
enum_type=KlingVideoGenDuration,
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Image to Video Node"
|
||||
|
||||
def get_response(self, task_id: str, auth_token: str) -> KlingImage2VideoResponse:
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingImage2VideoResponse:
|
||||
return poll_until_finished(
|
||||
auth_token,
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=KlingImage2VideoRequest,
|
||||
response_model=KlingImage2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
@@ -604,12 +696,14 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
duration: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
|
||||
validate_input_image(start_frame)
|
||||
|
||||
if camera_control is not None:
|
||||
# Camera control type for image 2 video is always simple
|
||||
# Camera control type for image 2 video is always `simple`
|
||||
camera_control.type = KlingCameraControlType.simple
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
@@ -631,18 +725,19 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
cfg_scale=cfg_scale,
|
||||
mode=KlingVideoGenMode(mode),
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(task_id, auth_token)
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@@ -692,7 +787,11 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
|
||||
@@ -705,7 +804,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
camera_control: KlingCameraControl,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||
@@ -717,7 +817,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
camera_control=camera_control,
|
||||
auth_token=auth_token,
|
||||
unique_id=unique_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -785,7 +886,11 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
|
||||
@@ -799,7 +904,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
mode: str,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
|
||||
mode
|
||||
@@ -814,7 +920,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=duration,
|
||||
end_frame=end_frame,
|
||||
auth_token=auth_token,
|
||||
unique_id=unique_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -844,22 +951,31 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
|
||||
|
||||
def get_response(self, task_id: str, auth_token: str) -> KlingVideoExtendResponse:
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoExtendResponse:
|
||||
return poll_until_finished(
|
||||
auth_token,
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoExtendResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
@@ -868,7 +984,8 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
negative_prompt: str,
|
||||
cfg_scale: float,
|
||||
video_id: str,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
initial_operation = SynchronousOperation(
|
||||
@@ -884,14 +1001,16 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
cfg_scale=cfg_scale,
|
||||
video_id=video_id,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(task_id, auth_token)
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@@ -904,15 +1023,20 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
|
||||
def get_response(self, task_id: str, auth_token: str) -> KlingVideoEffectsResponse:
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoEffectsResponse:
|
||||
return poll_until_finished(
|
||||
auth_token,
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
@@ -924,7 +1048,8 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
image_1: torch.Tensor,
|
||||
image_2: Optional[torch.Tensor] = None,
|
||||
mode: Optional[KlingVideoGenMode] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if dual_character:
|
||||
request_input_field = KlingDualCharacterEffectInput(
|
||||
@@ -954,14 +1079,16 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
effect_scene=effect_scene,
|
||||
input=request_input_field,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(task_id, auth_token)
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
@@ -1002,7 +1129,11 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite."
@@ -1017,7 +1148,8 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
model_name: KlingCharacterEffectModelName,
mode: KlingVideoGenMode,
duration: KlingVideoGenDuration,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
video, _, duration = super().api_call(
dual_character=True,
@@ -1027,10 +1159,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
duration=duration,
image_1=image_left,
image_2=image_right,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)
return video, duration


class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
"""Kling Single Image Video Effect Node"""

@@ -1063,7 +1197,11 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
@@ -1074,7 +1212,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
effect_scene: KlingSingleImageEffectsScene,
model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
dual_character=False,
@@ -1082,7 +1221,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
model_name=model_name,
duration=duration,
image_1=image,
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)


@@ -1092,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")

def validate_lip_sync_video(self, video: VideoInput):
"""
Validates the input video adheres to the expectations of the Kling Lip Sync API:
- Video length does not exceed 10s and is not shorter than 2s
- Length and width dimensions should both be between 720px and 1920px

See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
"""
validate_video_dimensions(video, 720, 1920)
validate_video_duration(video, 2, 10)

def validate_text(self, text: str):
if not text:
raise ValueError("Text is required")
@@ -1100,16 +1251,21 @@ class KlingLipSyncBase(KlingNodeBase):
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
)

def get_response(self, task_id: str, auth_token: str) -> KlingLipSyncResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingLipSyncResponse:
"""Polls the Kling API endpoint until the task reaches a terminal state."""
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_LIP_SYNC}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingLipSyncResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id,
)

def api_call(
@@ -1121,18 +1277,20 @@ class KlingLipSyncBase(KlingNodeBase):
text: Optional[str] = None,
voice_speed: Optional[float] = None,
voice_id: Optional[str] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
if text:
self.validate_text(text)
self.validate_lip_sync_video(video)

# Upload video to Comfy API and get download URL
video_url = upload_video_to_comfyapi(video, auth_token)
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
logging.info("Uploaded video to Comfy API. URL: %s", video_url)

# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = upload_audio_to_comfyapi(audio, auth_token)
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
audio_url = None
@@ -1156,14 +1314,16 @@ class KlingLipSyncBase(KlingNodeBase):
voice_id=voice_id,
),
),
auth_token=auth_token,
auth_kwargs=kwargs,
)

task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id

final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)

video = get_video_from_response(final_response)
@@ -1186,24 +1346,30 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
enum_type=KlingLipSyncVoiceLanguage,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

def api_call(
self,
video: VideoInput,
audio: AudioInput,
voice_language: str,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
video=video,
audio=audio,
voice_language=voice_language,
mode="audio2video",
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)


@@ -1292,10 +1458,14 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

def api_call(
self,
@@ -1303,7 +1473,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
text: str,
voice: str,
voice_speed: float,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
return super().api_call(
@@ -1313,7 +1484,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
voice_id=voice_id,
voice_speed=voice_speed,
mode="text2video",
auth_token=auth_token,
unique_id=unique_id,
**kwargs,
)


@@ -1350,22 +1522,29 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
enum_type=KlingVirtualTryOnModelName,
),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human."
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."

def get_response(
self, task_id: str, auth_token: Optional[str] = None
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVirtualTryOnResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=node_id,
)

def api_call(
@@ -1373,7 +1552,8 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
human_image: torch.Tensor,
cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -1387,14 +1567,16 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
cloth_image=tensor_to_base64_string(cloth_image),
model_name=model_name,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)

task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id

final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)

images = get_images_from_response(final_response)
@@ -1462,22 +1644,32 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
"optional": {
"image": (IO.IMAGE, {}),
},
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."

def get_response(
self, task_id: str, auth_token: Optional[str] = None
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]],
node_id: Optional[str] = None,
) -> KlingImageGenerationsResponse:
return poll_until_finished(
auth_token,
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=node_id,
)

def api_call(
@@ -1491,7 +1683,8 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
):
self.validate_prompt(prompt, negative_prompt)

@@ -1516,14 +1709,16 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n=n,
aspect_ratio=aspect_ratio,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)

task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id

final_response = self.get_response(task_id, auth_token)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)

images = get_images_from_response(final_response)
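The hunks above repeat one pattern across the Kling nodes: the single `auth_token` hidden input becomes a trio of hidden inputs (`auth_token`, `comfy_api_key`, `unique_id`), and `api_call` collects them through `**kwargs` and forwards the whole dict as `auth_kwargs`. A minimal sketch of that wiring, assuming ComfyUI's usual hidden-input behavior (the class and its prompt input below are illustrative, not part of this diff):

class ExampleApiNode:
    """Hypothetical node showing the hidden-input / auth_kwargs pattern."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "api_call"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"prompt": ("STRING", {"multiline": True})},
            # Hidden inputs are injected by the executor, not typed by the user.
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    def api_call(self, prompt, unique_id=None, **kwargs):
        # kwargs now carries auth_token and/or comfy_api_key; downstream helpers
        # receive the whole dict as auth_kwargs and use whichever credential is set.
        auth_kwargs = kwargs
        return (f"node {unique_id} would call the API with {sorted(auth_kwargs)}",)

Helpers such as `poll_until_finished` or `upload_video_to_comfyapi` can then pick whichever credential is present without every node signature having to name it explicitly.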
@@ -1,4 +1,6 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
@@ -34,11 +36,20 @@ from comfy_api_nodes.apinode_utils import (
process_image_response,
validate_string,
)
from server import PromptServer

import requests
import torch
from io import BytesIO

LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100

def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None

def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None

class LumaReferenceNode(ComfyNodeABC):
"""
@@ -201,6 +212,8 @@ class LumaImageGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -214,7 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3)
@@ -222,19 +235,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_token=auth_token
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = self._convert_style_image(
style_image, weight=style_image_weight, auth_token=auth_token
style_image, weight=style_image_weight, auth_kwargs=kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = upload_images_to_comfyapi(
character_image, max_images=4, auth_token=auth_token
character_image, max_images=4, auth_kwargs=kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@@ -255,7 +268,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()

@@ -269,7 +282,9 @@ class LumaImageGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()

@@ -278,13 +293,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
return (img,)

def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_token=None
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = upload_images_to_comfyapi(
ref.image, max_images=1, auth_token=auth_token
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
ref_count += 1
@@ -293,12 +308,12 @@ class LumaImageGenerationNode(ComfyNodeABC):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_token=None
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return self._convert_luma_refs(chain, max_refs=1, auth_token=auth_token)
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)


class LumaImageModifyNode(ComfyNodeABC):
@@ -350,6 +365,8 @@ class LumaImageModifyNode(ComfyNodeABC):
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -360,12 +377,12 @@ class LumaImageModifyNode(ComfyNodeABC):
image: torch.Tensor,
image_weight: float,
seed,
auth_token=None,
unique_id: str = None,
**kwargs,
):
# first, upload image
download_urls = upload_images_to_comfyapi(
image, max_images=1, auth_token=auth_token
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@@ -383,7 +400,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()

@@ -397,7 +414,9 @@ class LumaImageModifyNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()

@@ -470,6 +489,8 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -483,7 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3)
@@ -506,10 +527,13 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()

if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)

operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -520,7 +544,10 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()

@@ -594,6 +621,8 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -608,14 +637,14 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
auth_token=None,
unique_id: str = None,
**kwargs,
):
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_token)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None

@@ -636,10 +665,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()

if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)

operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -650,7 +682,10 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_token=auth_token,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()

@@ -661,7 +696,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_token=None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
@@ -669,12 +704,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame1 = None
if first_image is not None:
download_urls = upload_images_to_comfyapi(
first_image, max_images=1, auth_token=auth_token
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = upload_images_to_comfyapi(
last_image, max_images=1, auth_token=auth_token
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)
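The Luma changes mostly thread three extra arguments into `PollingOperation`: `result_url_extractor`, `node_id`, and `estimated_duration`, so the node can surface progress and the final asset URL while it waits. A generic stand-in for that loop, assuming a simple callback-based interface (the function below is illustrative, not the real `PollingOperation`):

import time

def poll_until_complete(fetch_status, is_done, is_failed, extract_url,
                        report=print, interval=5.0, timeout=600.0):
    # Stand-in for the PollingOperation pattern: poll a status endpoint, surface
    # the result URL as progress text for the node, and return the final response.
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = fetch_status()
        if is_failed(response):
            raise RuntimeError("generation failed")
        if is_done(response):
            url = extract_url(response)
            if url:
                report(f"Result URL: {url}")
            return response
        time.sleep(interval)
    raise TimeoutError("polling timed out")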
@@ -1,3 +1,7 @@
from typing import Union
import logging
import torch

from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
@@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer

import torch
import logging

I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234

class MinimaxTextToVideoNode:
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""

AVERAGE_DURATION = T2V_AVERAGE_DURATION

@classmethod
def INPUT_TYPES(s):
return {
@@ -67,6 +74,8 @@ class MinimaxTextToVideoNode:
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -84,7 +93,8 @@ class MinimaxTextToVideoNode:
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
auth_token=None,
unique_id: Union[str, None]=None,
**kwargs,
):
'''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
@@ -94,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_token=auth_token)[0]
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]

# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_token=auth_token)[0]
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]


@@ -118,7 +128,7 @@ class MinimaxTextToVideoNode:
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_token=auth_token,
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()

@@ -137,7 +147,9 @@ class MinimaxTextToVideoNode:
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
auth_token=auth_token,
estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()

@@ -153,7 +165,7 @@ class MinimaxTextToVideoNode:
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_token=auth_token,
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()

@@ -163,6 +175,12 @@ class MinimaxTextToVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)

video_io = download_url_to_bytesio(file_url)
if video_io is None:
@@ -177,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""

AVERAGE_DURATION = I2V_AVERAGE_DURATION

@classmethod
def INPUT_TYPES(s):
return {
@@ -221,6 +241,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -237,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""

AVERAGE_DURATION = T2V_AVERAGE_DURATION

@classmethod
def INPUT_TYPES(s):
return {
@@ -279,6 +303,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

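The MiniMax hunks keep the same create/poll/fetch sequence but now report the result (and any backup URL) back to the node via `PromptServer.instance.send_progress_text`. A rough outline of that flow with plain callables standing in for the operation objects (the names and dict keys below are assumptions for illustration only):

def run_minimax_style_task(create_task, poll_task, fetch_file, download, notify=None):
    # Hypothetical outline of the request/poll/fetch flow the node follows; the
    # callables stand in for the SynchronousOperation/PollingOperation objects above.
    task = create_task()                        # POST the generation request
    result = poll_task(task["task_id"])         # poll until "Success" / "Fail"
    file_info = fetch_file(result["file_id"])   # resolve the downloadable file
    url = file_info["download_url"]
    if notify is not None:
        backup = file_info.get("backup_download_url")
        notify(f"Result URL: {url}" + (f"\nBackup URL: {backup}" if backup else ""))
    return download(url)                        # bytes for the VIDEO output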
@@ -1,29 +1,86 @@
|
||||
import io
|
||||
from typing import TypedDict, Optional
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from server import PromptServer
|
||||
import folder_paths
|
||||
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
OpenAICreateResponse,
|
||||
OpenAIResponse,
|
||||
CreateModelResponseProperties,
|
||||
Item,
|
||||
Includable,
|
||||
OutputContent,
|
||||
InputImageContent,
|
||||
Detail,
|
||||
InputTextContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputContent,
|
||||
InputFileContent,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_and_cast_response,
|
||||
validate_string,
|
||||
tensor_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
|
||||
|
||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||
|
||||
|
||||
class HistoryEntry(TypedDict):
|
||||
"""Type definition for a single history entry in the chat."""
|
||||
|
||||
prompt: str
|
||||
response: str
|
||||
response_id: str
|
||||
timestamp: float
|
||||
|
||||
|
||||
class ChatHistory(TypedDict):
|
||||
"""Type definition for the chat history dictionary."""
|
||||
|
||||
__annotations__: dict[str, list[HistoryEntry]]
|
||||
|
||||
|
||||
class SupportedOpenAIModel(str, Enum):
|
||||
o4_mini = "o4-mini"
|
||||
o1 = "o1"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
gpt_4o = "gpt-4o"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
"""
|
||||
@@ -93,7 +150,11 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -110,7 +171,8 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-2"
|
||||
@@ -168,12 +230,12 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
else None
|
||||
),
|
||||
content_type=content_type,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -236,7 +298,11 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -252,7 +318,8 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
style="natural",
|
||||
quality="standard",
|
||||
size="1024x1024",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-3"
|
||||
@@ -273,12 +340,12 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
style=style,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -366,7 +433,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
@@ -385,12 +456,13 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
auth_token=None,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type="application/json"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
@@ -399,7 +471,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
if image is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
content_type ="multipart/form-data"
|
||||
content_type = "multipart/form-data"
|
||||
|
||||
batch_size = image.shape[0]
|
||||
|
||||
@@ -462,26 +534,475 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
),
|
||||
files=files if files else None,
|
||||
content_type=content_type,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
class OpenAITextNode(ComfyNodeABC):
|
||||
"""
|
||||
Base class for OpenAI text generation nodes.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/text/OpenAI"
|
||||
API_NODE = True
|
||||
|
||||
|
||||
class OpenAIChatNode(OpenAITextNode):
|
||||
"""
|
||||
Node to generate text responses from an OpenAI model.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the chat node with a new session ID and empty history."""
|
||||
self.current_session_id: str = str(uuid.uuid4())
|
||||
self.history: dict[str, list[HistoryEntry]] = {}
|
||||
self.previous_response_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text inputs to the model, used to generate a response.",
|
||||
},
|
||||
),
|
||||
"persist_context": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Persist chat context between calls (multi-turn conversation)",
|
||||
},
|
||||
),
|
||||
"model": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
OpenAICreateResponse,
|
||||
"model",
|
||||
enum_type=SupportedOpenAIModel,
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"OPENAI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
|
||||
},
|
||||
),
|
||||
"advanced_options": (
|
||||
"OPENAI_CHAT_CONFIG",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate text responses from an OpenAI model."
|
||||
|
||||
def get_result_response(
|
||||
self,
|
||||
response_id: str,
|
||||
include: Optional[list[Includable]] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> OpenAIResponse:
|
||||
"""
|
||||
Retrieve a model response with the given ID from the OpenAI API.
|
||||
|
||||
Args:
|
||||
response_id (str): The ID of the response to retrieve.
|
||||
include (Optional[List[Includable]]): Additional fields to include
|
||||
in the response. See the `include` parameter for Response
|
||||
creation above for more information.
|
||||
|
||||
"""
|
||||
return PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=OpenAIResponse,
|
||||
query_params={"include": include},
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
status_extractor=lambda response: response.status,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
def get_message_content_from_response(
|
||||
self, response: OpenAIResponse
|
||||
) -> list[OutputContent]:
|
||||
"""Extract message content from the API response."""
|
||||
for output in response.output:
|
||||
if output.root.type == "message":
|
||||
return output.root.content
|
||||
raise TypeError("No output message found in response")
|
||||
|
||||
def get_text_from_message_content(
|
||||
self, message_content: list[OutputContent]
|
||||
) -> str:
|
||||
"""Extract text content from message content."""
|
||||
for content_item in message_content:
|
||||
if content_item.root.type == "output_text":
|
||||
return str(content_item.root.text)
|
||||
return "No text output found in response"
|
||||
|
||||
def get_history_text(self, session_id: str) -> str:
|
||||
"""Convert the entire history for a given session to JSON string."""
|
||||
return json.dumps(self.history[session_id])
|
||||
|
||||
def display_history_on_node(self, session_id: str, node_id: str) -> None:
|
||||
"""Display formatted chat history on the node UI."""
|
||||
render_spec = {
|
||||
"node_id": node_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": self.get_history_text(session_id),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
def add_to_history(
|
||||
self, session_id: str, prompt: str, output_text: str, response_id: str
|
||||
) -> None:
|
||||
"""Add a new entry to the chat history."""
|
||||
if session_id not in self.history:
|
||||
self.history[session_id] = []
|
||||
self.history[session_id].append(
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": response_id,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
|
||||
"""Extract text output from the API response."""
|
||||
message_contents = self.get_message_content_from_response(response)
|
||||
return self.get_text_from_message_content(message_contents)
|
||||
|
||||
def generate_new_session_id(self) -> str:
|
||||
"""Generate a new unique session ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def get_session_id(self, persist_context: bool) -> str:
|
||||
"""Get the current or generate a new session ID based on context persistence."""
|
||||
return (
|
||||
self.current_session_id
|
||||
if persist_context
|
||||
else self.generate_new_session_id()
|
||||
)
|
||||
|
||||
def tensor_to_input_image_content(
|
||||
self, image: torch.Tensor, detail_level: Detail = "auto"
|
||||
) -> InputImageContent:
|
||||
"""Convert a tensor to an input image content object."""
|
||||
return InputImageContent(
|
||||
detail=detail_level,
|
||||
image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
|
||||
type="input_image",
|
||||
)
|
||||
|
||||
def create_input_message_contents(
|
||||
self,
|
||||
prompt: str,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
) -> InputMessageContentList:
|
||||
"""Create a list of input message contents from prompt and optional image."""
|
||||
content_list: list[InputContent] = [
|
||||
InputTextContent(text=prompt, type="input_text"),
|
||||
]
|
||||
if image is not None:
|
||||
for i in range(image.shape[0]):
|
||||
content_list.append(
|
||||
self.tensor_to_input_image_content(image[i].unsqueeze(0))
|
||||
)
|
||||
if files is not None:
|
||||
content_list.extend(files)
|
||||
|
||||
return InputMessageContentList(
|
||||
root=content_list,
|
||||
)
|
||||
|
||||
def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
|
||||
"""Extract response ID from prompt if it exists."""
|
||||
parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
|
||||
return parsed_id.group(1) if parsed_id else None
|
||||
|
||||
def strip_response_tag_from_prompt(self, prompt: str) -> str:
|
||||
"""Remove the response ID tag from the prompt."""
|
||||
return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
|
||||
|
||||
def delete_history_after_response_id(
|
||||
self, new_start_id: str, session_id: str
|
||||
) -> None:
|
||||
"""Delete history entries after a specific response ID."""
|
||||
if session_id not in self.history:
|
||||
return
|
||||
|
||||
new_history = []
|
||||
i = 0
|
||||
while (
|
||||
i < len(self.history[session_id])
|
||||
and self.history[session_id][i]["response_id"] != new_start_id
|
||||
):
|
||||
new_history.append(self.history[session_id][i])
|
||||
i += 1
|
||||
|
||||
# Since it's the new starting point (not the response being edited), we include it as well
|
||||
if i < len(self.history[session_id]):
|
||||
new_history.append(self.history[session_id][i])
|
||||
|
||||
self.history[session_id] = new_history
|
||||
|
||||
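The helpers above implement a small in-prompt convention: a `<starting_point_id:...>` tag selects which previous response the conversation should resume from, with the special value `start` clearing the history. A self-contained sketch of how such a tag might be parsed and stripped (it mirrors `STARTING_POINT_ID_PATTERN`; the example prompt is made up):

import re

STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"

def parse_starting_point(prompt: str):
    # "<starting_point_id:resp_123> shorter please" resumes from response resp_123;
    # "<starting_point_id:start> ..." restarts the conversation from scratch.
    match = re.search(STARTING_POINT_ID_PATTERN, prompt)
    cleaned = re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
    return (match.group(1) if match else None), cleaned

print(parse_starting_point("<starting_point_id:resp_123> shorter please"))
# ('resp_123', ' shorter please')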
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
persist_context: bool,
|
||||
model: SupportedOpenAIModel,
|
||||
unique_id: Optional[str] = None,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
advanced_options: Optional[CreateModelResponseProperties] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
session_id = self.get_session_id(persist_context)
|
||||
response_id_override = self.parse_response_id_from_prompt(prompt)
|
||||
if response_id_override:
|
||||
is_starting_from_beginning = response_id_override == "start"
|
||||
if is_starting_from_beginning:
|
||||
self.history[session_id] = []
|
||||
previous_response_id = None
|
||||
else:
|
||||
previous_response_id = response_id_override
|
||||
self.delete_history_after_response_id(response_id_override, session_id)
|
||||
prompt = self.strip_response_tag_from_prompt(prompt)
|
||||
elif persist_context:
|
||||
previous_response_id = self.previous_response_id
|
||||
else:
|
||||
previous_response_id = None
|
||||
|
||||
# Create response
|
||||
create_response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=RESPONSES_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=OpenAICreateResponse,
|
||||
response_model=OpenAIResponse,
|
||||
),
|
||||
request=OpenAICreateResponse(
|
||||
input=[
|
||||
Item(
|
||||
root=InputMessage(
|
||||
content=self.create_input_message_contents(
|
||||
prompt, images, files
|
||||
),
|
||||
role="user",
|
||||
)
|
||||
),
|
||||
],
|
||||
store=True,
|
||||
stream=False,
|
||||
model=model,
|
||||
previous_response_id=previous_response_id,
|
||||
**(
|
||||
advanced_options.model_dump(exclude_none=True)
|
||||
if advanced_options
|
||||
else {}
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
response_id = create_response.id
|
||||
|
||||
# Get result output
|
||||
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||
output_text = self.parse_output_text_from_response(result_response)
|
||||
|
||||
# Update history
|
||||
self.add_to_history(session_id, prompt, output_text, response_id)
|
||||
self.display_history_on_node(session_id, unique_id)
|
||||
self.previous_response_id = response_id
|
||||
|
||||
return (output_text,)
|
||||
|
||||
|
||||
class OpenAIInputFiles(ComfyNodeABC):
|
||||
"""
|
||||
Loads and formats input files for OpenAI API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
"""
|
||||
For details about the supported file input types, see:
|
||||
https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
|
||||
"""
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
input_files = [
|
||||
f
|
||||
for f in os.scandir(input_dir)
|
||||
if f.is_file()
|
||||
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
|
||||
and f.stat().st_size < 32 * 1024 * 1024
|
||||
]
|
||||
input_files = sorted(input_files, key=lambda x: x.name)
|
||||
input_files = [f.name for f in input_files]
|
||||
return {
|
||||
"required": {
|
||||
"file": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
"options": input_files,
|
||||
"default": input_files[0] if input_files else None,
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"OPENAI_INPUT_FILES": (
|
||||
"OPENAI_INPUT_FILES",
|
||||
{
|
||||
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
|
||||
RETURN_TYPES = ("OPENAI_INPUT_FILES",)
|
||||
FUNCTION = "prepare_files"
|
||||
CATEGORY = "api node/text/OpenAI"
|
||||
|
||||
def create_input_file_content(self, file_path: str) -> InputFileContent:
|
||||
return InputFileContent(
|
||||
file_data=text_filepath_to_data_uri(file_path),
|
||||
filename=os.path.basename(file_path),
|
||||
type="input_file",
|
||||
)
|
||||
|
||||
def prepare_files(
|
||||
self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
|
||||
) -> tuple[list[InputFileContent]]:
|
||||
"""
|
||||
Loads and formats input files for OpenAI API.
|
||||
"""
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
input_file_content = self.create_input_file_content(file_path)
|
||||
files = [input_file_content] + OPENAI_INPUT_FILES
|
||||
return (files,)
|
||||
|
||||
|
||||
class OpenAIChatConfig(ComfyNodeABC):
|
||||
"""Allows setting additional configuration for the OpenAI Chat Node."""
|
||||
|
||||
RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
|
||||
FUNCTION = "configure"
|
||||
DESCRIPTION = (
|
||||
"Allows specifying advanced configuration options for the OpenAI Chat Nodes."
|
||||
)
|
||||
CATEGORY = "api node/text/OpenAI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"truncation": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["auto", "disabled"],
|
||||
"default": "auto",
|
||||
"tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"max_output_tokens": model_field_to_node_input(
|
||||
IO.INT,
|
||||
OpenAICreateResponse,
|
||||
"max_output_tokens",
|
||||
min=16,
|
||||
default=4096,
|
||||
max=16384,
|
||||
tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
|
||||
),
|
||||
"instructions": model_field_to_node_input(
|
||||
IO.STRING, OpenAICreateResponse, "instructions", multiline=True
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def configure(
|
||||
self,
|
||||
truncation: bool,
|
||||
instructions: Optional[str] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
) -> tuple[CreateModelResponseProperties]:
|
||||
"""
|
||||
Configure advanced options for the OpenAI Chat Node.
|
||||
|
||||
Note:
|
||||
While `top_p` and `temperature` are listed as properties in the
|
||||
spec, they are not supported for all models (e.g., o4-mini).
|
||||
They are not exposed as inputs at all to avoid having to manually
|
||||
remove depending on model choice.
|
||||
"""
|
||||
return (
|
||||
CreateModelResponseProperties(
|
||||
instructions=instructions,
|
||||
truncation=truncation,
|
||||
max_output_tokens=max_output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"OpenAIDalle2": OpenAIDalle2,
|
||||
"OpenAIDalle3": OpenAIDalle3,
|
||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||
"OpenAIChatNode": OpenAIChatNode,
|
||||
"OpenAIInputFiles": OpenAIInputFiles,
|
||||
"OpenAIChatConfig": OpenAIChatConfig,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
"OpenAIChatNode": "OpenAI Chat",
|
||||
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
||||
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
||||
}
|
||||
|
||||
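One detail worth noting from the new OpenAI nodes: each Chat Input Files node prepends its own file to whatever list arrives from an upstream Input Files node, which is how chained nodes accumulate attachments for a single message. A simplified stand-in, assuming plain text files and a text/plain data URI (the helpers below are illustrative, not the module's real `text_filepath_to_data_uri`):

from base64 import b64encode

def text_file_to_data_uri(path: str) -> str:
    # Simplified stand-in: inline the file so the Responses API receives it as an
    # input_file content part.
    with open(path, "rb") as f:
        return "data:text/plain;base64," + b64encode(f.read()).decode()

def chain_input_files(path: str, upstream=None):
    # Each Input Files node prepends its own file to the upstream list, so several
    # chained nodes accumulate into one list of attachments for a single message.
    entry = {
        "type": "input_file",
        "filename": path.rsplit("/", 1)[-1],
        "file_data": text_file_to_data_uri(path),
    }
    return [entry] + (upstream or [])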
@@ -3,42 +3,45 @@ Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Optional, TypeVar
|
||||
import logging
|
||||
import torch
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
)
|
||||
from comfy_api_nodes.apis import (
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
PikaGenerateResponse,
|
||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
PikaVideoResponse,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
IngredientsMode,
|
||||
PikaDurationEnum,
|
||||
PikaResolutionEnum,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
PikaDurationEnum,
|
||||
Pikaffect,
|
||||
PikaGenerateResponse,
|
||||
PikaResolutionEnum,
|
||||
PikaVideoResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
HttpMethod,
|
||||
PollingOperation,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
@@ -120,7 +123,10 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def poll_for_task_status(
|
||||
self, task_id: str, auth_token: str
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> PikaGenerateResponse:
|
||||
polling_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -139,20 +145,26 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
progress_extractor=lambda response: (
|
||||
response.progress if hasattr(response, "progress") else None
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (
|
||||
response.url if hasattr(response, "url") else None
|
||||
),
|
||||
node_id=node_id,
|
||||
estimated_duration=60
|
||||
)
|
||||
return polling_operation.execute()
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_token: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
"""Executes the initial operation then polls for the task status until it is completed.
|
||||
|
||||
Args:
|
||||
initial_operation: The initial operation to execute.
|
||||
auth_token: The authentication token to use for the API call.
|
||||
auth_kwargs: The authentication token(s) to use for the API call.
|
||||
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
@@ -164,7 +176,7 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = self.poll_for_task_status(task_id, auth_token)
|
||||
final_response = self.poll_for_task_status(task_id, auth_kwargs)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
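`execute_task` wraps the same two-phase contract every Pika node now uses: the synchronous call returns a `video_id`, then the node polls that id and downloads the finished clip. A sketch of that control flow with plain callables in place of the operation objects (error messages and dict keys are illustrative):

def run_pika_task(submit, poll, fetch_video):
    # Sketch of the contract execute_task wraps: the synchronous call returns a
    # video_id, the node then polls that id and downloads the finished clip.
    initial = submit()                      # SynchronousOperation.execute()
    if not initial.get("video_id"):
        raise RuntimeError(f"Pika task creation failed: {initial}")
    final = poll(initial["video_id"])       # poll_for_task_status(...)
    if not final.get("url"):
        raise RuntimeError("Pika task finished but returned no video URL")
    return fetch_video(final["url"])        # download_url_to_video_output(...)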
@@ -193,6 +205,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -206,7 +220,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
@@ -233,10 +248,10 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
@@ -259,6 +274,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -272,7 +289,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -289,11 +307,11 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaScenesV2_2(PikaNodeBase):
|
||||
@@ -336,6 +354,8 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -350,12 +370,13 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
unique_id: str,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert all passed images to BytesIO
|
||||
all_image_bytes_io = []
|
||||
@@ -396,10 +417,10 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikAdditionsNode(PikaNodeBase):
|
||||
@@ -434,10 +455,12 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
@@ -446,7 +469,8 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
@@ -479,10 +503,10 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaSwapsNode(PikaNodeBase):
|
||||
@@ -526,6 +550,8 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -540,7 +566,8 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
@@ -583,10 +610,10 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaffectsNode(PikaNodeBase):
|
||||
@@ -630,6 +657,8 @@ class PikaffectsNode(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -642,7 +671,8 @@ class PikaffectsNode(PikaNodeBase):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
@@ -660,10 +690,10 @@ class PikaffectsNode(PikaNodeBase):
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
@@ -681,6 +711,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -695,7 +727,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
auth_token: Optional[str] = None,
|
||||
unique_id: str,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
pika_files = [
|
||||
@@ -722,10 +755,10 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_token)
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from inspect import cleandoc
|
||||
|
||||
from typing import Optional
|
||||
from comfy_api_nodes.apis.pixverse_api import (
|
||||
PixverseTextVideoRequest,
|
||||
PixverseImageVideoRequest,
|
||||
@@ -34,11 +34,22 @@ import requests
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
|
||||
AVERAGE_DURATION_T2V = 32
|
||||
AVERAGE_DURATION_I2V = 30
|
||||
AVERAGE_DURATION_T2T = 52
|
||||
|
||||
|
||||
def get_video_url_from_response(
|
||||
response: PixverseGenerationStatusResponse,
|
||||
) -> Optional[str]:
|
||||
if response.Resp is None or response.Resp.url is None:
|
||||
return None
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
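get_video_url_from_response is later passed as result_url_extractor, so it has to tolerate partially populated status responses. A tiny illustration with a stubbed response object (the Resp shape here is an assumption, not the real pydantic model):

from types import SimpleNamespace
from typing import Optional

def get_video_url(response) -> Optional[str]:
    # Mirrors the extractor above: return None until Resp.url is populated.
    if response.Resp is None or response.Resp.url is None:
        return None
    return str(response.Resp.url)

pending = SimpleNamespace(Resp=None)
done = SimpleNamespace(Resp=SimpleNamespace(url="https://example.com/video.mp4"))
assert get_video_url(pending) is None
assert get_video_url(done) == "https://example.com/video.mp4"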
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {
|
||||
"image": tensor_to_bytesio(image)
|
||||
}
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/image/upload",
|
||||
@@ -49,12 +60,14 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||
raise Exception(
|
||||
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
|
||||
)
|
||||
|
||||
return response_upload.Resp.img_id
|
||||
|
||||
@@ -73,7 +86,7 @@ class PixverseTemplateNode:
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"template": (list(pixverse_templates.keys()), ),
|
||||
"template": (list(pixverse_templates.keys()),),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +100,7 @@ class PixverseTemplateNode:
|
||||
|
||||
class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
@@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in PixverseAspectRatio],
|
||||
),
|
||||
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
@@ -143,11 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -159,9 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
pixverse_template: int=None,
|
||||
auth_token=None,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@@ -190,7 +203,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
@@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
@@ -273,11 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -289,13 +310,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
pixverse_template: int=None,
|
||||
auth_token=None,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
img_id = upload_image_to_pixverse(image, auth_token=auth_token)
|
||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -322,7 +343,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
@@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
|
||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
@@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"first_frame": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"last_frame": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"first_frame": (IO.IMAGE,),
|
||||
"last_frame": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
@@ -407,6 +431,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -419,13 +445,13 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
auth_token=None,
|
||||
negative_prompt: str = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_token=auth_token)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_token=auth_token)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -452,7 +478,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
PixverseStatus.failed,
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_extras.nodes_images import SVG # Added
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.apis.recraft_api import (
|
||||
RecraftImageGenerationRequest,
|
||||
@@ -28,9 +30,8 @@ from comfy_api_nodes.apinode_utils import (
|
||||
resize_mask_to_image,
|
||||
validate_string,
|
||||
)
|
||||
import folder_paths
|
||||
import json
|
||||
import os
|
||||
from server import PromptServer
|
||||
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from PIL import UnidentifiedImageError
|
||||
@@ -43,7 +44,7 @@ def handle_recraft_file_request(
|
||||
total_pixels=4096*4096,
|
||||
timeout=1024,
|
||||
request=None,
|
||||
auth_token=None
|
||||
auth_kwargs: dict[str,str] = None,
|
||||
) -> list[BytesIO]:
|
||||
"""
|
||||
Handle sending common Recraft file-only request to get back file bytes.
|
||||
@@ -67,7 +68,7 @@ def handle_recraft_file_request(
|
||||
request=request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=auth_kwargs,
|
||||
multipart_parser=recraft_multipart_parser,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
@@ -162,102 +163,6 @@ class handle_recraft_image_output:
|
||||
raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
|
||||
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: SVG):
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list[SVG]):
|
||||
all_svgs = []
|
||||
for svg in svgs:
|
||||
all_svgs.extend(svg.data)
|
||||
return SVG(all_svgs)
|
||||
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "api node/image/Recraft"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": (RecraftIO.SVG,),
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
|
||||
# Read SVG content
|
||||
svg_bytes.seek(0)
|
||||
svg_content = svg_bytes.read().decode('utf-8')
|
||||
|
||||
# Inject metadata if available
|
||||
if metadata_json:
|
||||
# Create metadata element with CDATA section
|
||||
metadata_element = f""" <metadata>
|
||||
<![CDATA[
|
||||
{metadata_json}
|
||||
]]>
|
||||
</metadata>
|
||||
"""
|
||||
# Insert metadata after opening svg tag using regex
|
||||
import re
|
||||
svg_content = re.sub(r'(<svg[^>]*>)', r'\1\n' + metadata_element, svg_content)
|
||||
|
||||
# Write the modified SVG to file
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
|
||||
class RecraftColorRGBNode:
|
||||
"""
|
||||
Create Recraft Color by choosing specific RGB values.
|
||||
@@ -485,6 +390,8 @@ class RecraftTextToImageNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -497,7 +404,7 @@ class RecraftTextToImageNode:
|
||||
recraft_style: RecraftStyle = None,
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
auth_token=None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False, max_length=1000)
|
||||
@@ -530,12 +437,19 @@ class RecraftTextToImageNode:
|
||||
style_id=recraft_style.style_id,
|
||||
controls=controls_api,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
images = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
with handle_recraft_image_output():
|
||||
if unique_id and data.url:
|
||||
urls.append(data.url)
|
||||
urls_string = '\n'.join(urls)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {urls_string}", unique_id
|
||||
)
|
||||
image = bytesio_to_image_tensor(
|
||||
download_url_to_bytesio(data.url, timeout=1024)
|
||||
)
|
||||
@@ -620,6 +534,7 @@ class RecraftImageToImageNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -630,7 +545,6 @@ class RecraftImageToImageNode:
|
||||
n: int,
|
||||
strength: float,
|
||||
seed,
|
||||
auth_token=None,
|
||||
recraft_style: RecraftStyle = None,
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
@@ -668,7 +582,7 @@ class RecraftImageToImageNode:
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/imageToImage",
|
||||
request=request,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
with handle_recraft_image_output():
|
||||
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
|
||||
@@ -736,6 +650,7 @@ class RecraftImageInpaintingNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -746,7 +661,6 @@ class RecraftImageInpaintingNode:
|
||||
prompt: str,
|
||||
n: int,
|
||||
seed,
|
||||
auth_token=None,
|
||||
recraft_style: RecraftStyle = None,
|
||||
negative_prompt: str = None,
|
||||
**kwargs,
|
||||
@@ -781,7 +695,7 @@ class RecraftImageInpaintingNode:
|
||||
mask=mask[i:i+1],
|
||||
path="/proxy/recraft/images/inpaint",
|
||||
request=request,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
with handle_recraft_image_output():
|
||||
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
|
||||
@@ -796,8 +710,8 @@ class RecraftTextToVectorNode:
|
||||
Generates SVG synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (RecraftIO.SVG,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
RETURN_TYPES = ("SVG",) # Changed
|
||||
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Recraft"
|
||||
@@ -860,6 +774,8 @@ class RecraftTextToVectorNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -872,7 +788,7 @@ class RecraftTextToVectorNode:
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
auth_token=None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False, max_length=1000)
|
||||
@@ -903,11 +819,18 @@ class RecraftTextToVectorNode:
|
||||
substyle=recraft_style.substyle,
|
||||
controls=controls_api,
|
||||
),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
svg_data = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
if unique_id and data.url:
|
||||
urls.append(data.url)
|
||||
# Print result on each iteration in case of error
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {' '.join(urls)}", unique_id
|
||||
)
|
||||
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
|
||||
|
||||
return (SVG(svg_data),)
|
||||
@@ -918,8 +841,8 @@ class RecraftVectorizeImageNode:
|
||||
Generates SVG synchronously from an input image.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (RecraftIO.SVG,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
RETURN_TYPES = ("SVG",) # Changed
|
||||
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Recraft"
|
||||
@@ -934,13 +857,13 @@ class RecraftVectorizeImageNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
auth_token=None,
|
||||
**kwargs,
|
||||
):
|
||||
svgs = []
|
||||
@@ -950,7 +873,7 @@ class RecraftVectorizeImageNode:
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/vectorize",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
svgs.append(SVG(sub_bytes))
|
||||
pbar.update(1)
|
||||
@@ -1015,6 +938,7 @@ class RecraftReplaceBackgroundNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1024,7 +948,6 @@ class RecraftReplaceBackgroundNode:
|
||||
prompt: str,
|
||||
n: int,
|
||||
seed,
|
||||
auth_token=None,
|
||||
recraft_style: RecraftStyle = None,
|
||||
negative_prompt: str = None,
|
||||
**kwargs,
|
||||
@@ -1054,7 +977,7 @@ class RecraftReplaceBackgroundNode:
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/replaceBackground",
|
||||
request=request,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
|
||||
pbar.update(1)
|
||||
@@ -1084,13 +1007,13 @@ class RecraftRemoveBackgroundNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
auth_token=None,
|
||||
**kwargs,
|
||||
):
|
||||
images = []
|
||||
@@ -1100,7 +1023,7 @@ class RecraftRemoveBackgroundNode:
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/removeBackground",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
|
||||
pbar.update(1)
|
||||
@@ -1135,13 +1058,13 @@ class RecraftCrispUpscaleNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
auth_token=None,
|
||||
**kwargs,
|
||||
):
|
||||
images = []
|
||||
@@ -1151,7 +1074,7 @@ class RecraftCrispUpscaleNode:
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path=self.RECRAFT_PATH,
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
|
||||
pbar.update(1)
|
||||
@@ -1193,7 +1116,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
|
||||
"RecraftColorRGB": RecraftColorRGBNode,
|
||||
"RecraftControls": RecraftControlsNode,
|
||||
"SaveSVG": SaveSVGNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
@@ -1213,5 +1135,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
|
||||
"RecraftColorRGB": "Recraft Color RGB",
|
||||
"RecraftControls": "Recraft Controls",
|
||||
"SaveSVG": "Save SVG",
|
||||
}
|
||||
|
||||
comfy_api_nodes/nodes_rodin.py (new file, 462 lines)
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
ComfyUI X Rodin3D(Deemos) API Nodes
|
||||
|
||||
Rodin API docs: https://developer.hyper3d.ai/
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
import folder_paths as comfy_paths
|
||||
import requests
|
||||
import os
|
||||
import datetime
|
||||
import shutil
|
||||
import time
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
from PIL import Image
|
||||
from comfy_api_nodes.apis.rodin_api import (
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
Rodin3DCheckStatusRequest,
|
||||
Rodin3DCheckStatusResponse,
|
||||
Rodin3DDownloadRequest,
|
||||
Rodin3DDownloadResponse,
|
||||
JobStatus,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
)
|
||||
|
||||
|
||||
COMMON_PARAMETERS = {
|
||||
"Seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default":0,
|
||||
"min":0,
|
||||
"max":65535,
|
||||
"display":"number"
|
||||
}
|
||||
),
|
||||
"Material_Type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["PBR", "Shaded"],
|
||||
"default": "PBR"
|
||||
}
|
||||
),
|
||||
"Polygon_count": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
|
||||
"default": "18K-Quad"
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
def create_task_error(response: Rodin3DGenerateResponse):
|
||||
"""Check if the response has error"""
|
||||
return hasattr(response, "error")
|
||||
|
||||
|
||||
|
||||
class Rodin3DAPI:
|
||||
"""
|
||||
Generate 3D Assets using Rodin API
|
||||
"""
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
RETURN_NAMES = ("3D Model Path",)
|
||||
CATEGORY = "api node/3d/Rodin"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
|
||||
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
|
||||
"""
|
||||
Converts a PyTorch tensor to a file-like object.
|
||||
|
||||
Args:
|
||||
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
|
||||
where C is the number of channels (3 for RGB), H is height, and W is width.
|
||||
|
||||
Returns:
|
||||
- io.BytesIO: A file-like object containing the image data.
|
||||
"""
|
||||
array = tensor.cpu().numpy()
|
||||
array = (array * 255).astype('uint8')
|
||||
image = Image.fromarray(array, 'RGB')
|
||||
|
||||
original_width, original_height = image.size
|
||||
original_pixels = original_width * original_height
|
||||
if original_pixels > max_pixels:
|
||||
scale = math.sqrt(max_pixels / original_pixels)
|
||||
new_width = int(original_width * scale)
|
||||
new_height = int(original_height * scale)
|
||||
else:
|
||||
new_width, new_height = original_width, original_height
|
||||
|
||||
if new_width != original_width or new_height != original_height:
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
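For intuition on the downscaling rule above: a 4096x4096 input has four times the pixel budget of max_pixels=2048*2048, so scale = sqrt(1/4) = 0.5 and the image is resized to 2048x2048. A small self-contained check of just the math (no PIL needed):

import math

def target_size(width: int, height: int, max_pixels: int = 2048 * 2048) -> tuple[int, int]:
    # Same area-preserving rule as tensor_to_filelike above.
    if width * height <= max_pixels:
        return width, height
    scale = math.sqrt(max_pixels / (width * height))
    return int(width * scale), int(height * scale)

assert target_size(4096, 4096) == (2048, 2048)
assert target_size(1920, 1080) == (1920, 1080)   # already under the budget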
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
|
||||
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
|
||||
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
||||
status_list = [str(job.status) for job in response.jobs]
|
||||
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
|
||||
if has_failed:
|
||||
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
|
||||
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
||||
elif all_done:
|
||||
return "DONE"
|
||||
else:
|
||||
return "Generating"
|
||||
|
||||
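The extractor above reduces a list of per-job states to the single string that the poller later matches against completed_statuses=["DONE"]. A toy version of that reduction (the JobState members here are assumed for illustration and are not the real JobStatus enum):

from enum import Enum

class JobState(str, Enum):          # illustrative stand-in for JobStatus
    Done = "Done"
    Failed = "Failed"
    Running = "Running"             # assumed name for any in-progress state

def aggregate(states: list[JobState]) -> str:
    if any(s == JobState.Failed for s in states):
        raise Exception("Generation failed.")
    return "DONE" if all(s == JobState.Done for s in states) else "Generating"

assert aggregate([JobState.Done, JobState.Done]) == "DONE"
assert aggregate([JobState.Done, JobState.Running]) == "Generating"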
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate supports up to 5 images.")
|
||||
|
||||
path = "/proxy/rodin/api/v2/rodin"
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Rodin3DGenerateRequest,
|
||||
response_model=Rodin3DGenerateResponse,
|
||||
),
|
||||
request=Rodin3DGenerateRequest(
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
material=material,
|
||||
quality=quality,
|
||||
mesh_mode=mesh_mode
|
||||
),
|
||||
files=[
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
|
||||
)
|
||||
for image in images if image is not None
|
||||
],
|
||||
content_type = "multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
if create_task_error(response):
|
||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
||||
subscription_key = response.jobs.subscription_key
|
||||
task_uuid = response.uuid
|
||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||
return task_uuid, subscription_key
|
||||
|
||||
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||
|
||||
path = "/proxy/rodin/api/v2/status"
|
||||
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path = path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Rodin3DCheckStatusRequest,
|
||||
response_model=Rodin3DCheckStatusResponse,
|
||||
),
|
||||
request=Rodin3DCheckStatusRequest(
|
||||
subscription_key = subscription_key
|
||||
),
|
||||
completed_statuses=["DONE"],
|
||||
failed_statuses=["FAILED"],
|
||||
status_extractor=self.check_rodin_status,
|
||||
poll_interval=3.0,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||
|
||||
return poll_operation.execute()
|
||||
|
||||
|
||||
|
||||
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||
|
||||
path = "/proxy/rodin/api/v2/download"
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Rodin3DDownloadRequest,
|
||||
response_model=Rodin3DDownloadResponse,
|
||||
),
|
||||
request=Rodin3DDownloadRequest(
|
||||
task_uuid=uuid
|
||||
),
|
||||
auth_kwargs=kwargs
|
||||
)
|
||||
|
||||
return operation.execute()
|
||||
|
||||
def GetQualityAndMode(self, PolyCount):
|
||||
if PolyCount == "200K-Triangle":
|
||||
mesh_mode = "Raw"
|
||||
quality = "medium"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
if PolyCount == "4K-Quad":
|
||||
quality = "extra-low"
|
||||
elif PolyCount == "8K-Quad":
|
||||
quality = "low"
|
||||
elif PolyCount == "18K-Quad":
|
||||
quality = "medium"
|
||||
elif PolyCount == "50K-Quad":
|
||||
quality = "high"
|
||||
else:
|
||||
quality = "medium"
|
||||
|
||||
return mesh_mode, quality
|
||||
|
||||
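The same polygon-count mapping, written as a lookup table for reference (illustrative only; the diff above keeps the if/elif form):

POLY_COUNT_TO_QUALITY = {
    "4K-Quad": ("Quad", "extra-low"),
    "8K-Quad": ("Quad", "low"),
    "18K-Quad": ("Quad", "medium"),
    "50K-Quad": ("Quad", "high"),
    "200K-Triangle": ("Raw", "medium"),
}

def get_quality_and_mode(poly_count: str) -> tuple[str, str]:
    # Unknown labels fall back to a medium-quality quad mesh, matching the
    # final else branch in GetQualityAndMode above.
    mesh_mode, quality = POLY_COUNT_TO_QUALITY.get(poly_count, ("Quad", "medium"))
    return mesh_mode, quality

assert get_quality_and_mode("200K-Triangle") == ("Raw", "medium")
assert get_quality_and_mode("18K-Quad") == ("Quad", "medium")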
def DownLoadFiles(self, Url_List):
|
||||
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
os.makedirs(Save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
for Item in Url_List.list:
|
||||
url = Item.url
|
||||
file_name = Item.name
|
||||
file_path = os.path.join(Save_path, file_name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||
if attempt < max_retries - 1:
|
||||
logging.info("Retrying...")
|
||||
time.sleep(2)
|
||||
else:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
|
||||
|
||||
return model_file_path
|
||||
|
||||
|
||||
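DownLoadFiles retries each file up to five times with a fixed 2-second pause. If flakier links become a problem, exponential backoff is a common variant; a hedged sketch is below, where download_once stands in for the requests-based body above and is not part of this PR.

import time

def download_with_backoff(download_once, max_retries: int = 5, base_delay: float = 2.0):
    # download_once() should raise on failure, mirroring requests' raise_for_status().
    for attempt in range(max_retries):
        try:
            return download_once()
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))   # 2s, 4s, 8s, ...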
class Rodin3D_Regular(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"Images":
|
||||
(
|
||||
IO.IMAGE,
|
||||
{
|
||||
"forceInput":True,
|
||||
}
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
**COMMON_PARAMETERS
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
**kwargs
|
||||
):
|
||||
tier = "Regular"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
|
||||
return (model,)
|
||||
|
||||
class Rodin3D_Detail(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"Images":
|
||||
(
|
||||
IO.IMAGE,
|
||||
{
|
||||
"forceInput":True,
|
||||
}
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
**COMMON_PARAMETERS
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
**kwargs
|
||||
):
|
||||
tier = "Detail"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
|
||||
return (model,)
|
||||
|
||||
class Rodin3D_Smooth(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"Images":
|
||||
(
|
||||
IO.IMAGE,
|
||||
{
|
||||
"forceInput":True,
|
||||
}
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
**COMMON_PARAMETERS
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
**kwargs
|
||||
):
|
||||
tier = "Smooth"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
|
||||
return (model,)
|
||||
|
||||
class Rodin3D_Sketch(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"Images":
|
||||
(
|
||||
IO.IMAGE,
|
||||
{
|
||||
"forceInput":True,
|
||||
}
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
"Seed":
|
||||
(
|
||||
IO.INT,
|
||||
{
|
||||
"default":0,
|
||||
"min":0,
|
||||
"max":65535,
|
||||
"display":"number"
|
||||
}
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
**kwargs
|
||||
):
|
||||
tier = "Sketch"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
material_type = "PBR"
|
||||
quality = "medium"
|
||||
mesh_mode = "Quad"
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
|
||||
return (model,)
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Rodin3D_Regular": Rodin3D_Regular,
|
||||
"Rodin3D_Detail": Rodin3D_Detail,
|
||||
"Rodin3D_Smooth": Rodin3D_Smooth,
|
||||
"Rodin3D_Sketch": Rodin3D_Sketch,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
|
||||
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
|
||||
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
|
||||
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
|
||||
}
|
||||
comfy_api_nodes/nodes_runway.py (new file, 635 lines)
@@ -0,0 +1,635 @@
|
||||
"""Runway API Nodes
|
||||
|
||||
API Docs:
|
||||
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
|
||||
|
||||
User Guides:
|
||||
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
|
||||
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
|
||||
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
|
||||
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||
RunwayTaskStatusEnum as TaskStatus,
|
||||
RunwayModelEnum as Model,
|
||||
RunwayDurationEnum as Duration,
|
||||
RunwayAspectRatioEnum as AspectRatio,
|
||||
RunwayPromptImageObject,
|
||||
RunwayPromptImageDetailedObject,
|
||||
RunwayTextToImageRequest,
|
||||
RunwayTextToImageResponse,
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
AVERAGE_DURATION_I2V_SECONDS = 64
|
||||
AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
field_1280_720 = "1280:720"
|
||||
field_720_1280 = "720:1280"
|
||||
field_1104_832 = "1104:832"
|
||||
field_832_1104 = "832:1104"
|
||||
field_960_960 = "960:960"
|
||||
field_1584_672 = "1584:672"
|
||||
|
||||
|
||||
class RunwayGen3aAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
|
||||
|
||||
field_768_1280 = "768:1280"
|
||||
field_1280_768 = "1280:768"
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
# TODO: replace with updated image validation utils (upstream)
|
||||
def validate_input_image(image: torch.Tensor) -> bool:
|
||||
"""
|
||||
Validate the input image is within the size limits for the Runway API.
|
||||
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
|
||||
"""
|
||||
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||
|
||||
|
||||
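Note the indexing in the check above: ComfyUI image tensors are (batch, height, width, channels), so shape[1] is height and shape[2] is width, and both must stay under 8000 px per Runway's asset limits. For example:

import torch

image = torch.zeros((1, 1080, 1920, 3))      # batch, height, width, channels
assert image.shape[2] < 8000 and image.shape[1] < 8000          # passes validation

too_wide = torch.zeros((1, 512, 9000, 3))
assert not (too_wide.shape[2] < 8000 and too_wide.shape[1] < 8000)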
def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
],
|
||||
failed_statuses=[
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: (response.status.value),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
).execute()
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if response.output and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
return None
|
||||
|
||||
|
||||
class RunwayVideoGenNode(ComfyNodeABC):
|
||||
"""Runway Video Node Base."""
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/video/Runway"
|
||||
API_NODE = True
|
||||
|
||||
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
|
||||
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no video data found in response."
|
||||
)
|
||||
return True
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
final_response = self.get_response(task_id, auth_kwargs, node_id)
|
||||
self.validate_response(final_response)
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return (download_url_to_video_output(video_url),)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen3a Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
"seed",
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
"""Runway Image to Video Node using Gen4 Turbo model."""
|
||||
|
||||
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayGen4TurboAspectRatio,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
"seed",
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen4_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
)
|
||||
|
||||
|
||||
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
"""Runway First-Last Frame Node."""
|
||||
|
||||
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||
),
|
||||
"start_frame": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Start frame to be used for the video"},
|
||||
),
|
||||
"end_frame": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
|
||||
},
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayImageToVideoRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayGen3aAspectRatio,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
RunwayImageToVideoRequest,
|
||||
"seed",
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_input_image(start_frame)
|
||||
validate_input_image(end_frame)
|
||||
|
||||
# Upload images
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
model=Model("gen3a_turbo"),
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
node_id=unique_id,
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageNode(ComfyNodeABC):
|
||||
"""Runway Text to Image Node."""
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Runway"
|
||||
API_NODE = True
|
||||
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
|
||||
),
|
||||
"ratio": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
RunwayTextToImageRequest,
|
||||
"ratio",
|
||||
enum_type=RunwayTextToImageAspectRatioEnum,
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"reference_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Optional reference image to guide the generation"},
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
|
||||
"""
|
||||
Validate the task creation response from the Runway API matches
|
||||
expected format.
|
||||
"""
|
||||
if not bool(response.id):
|
||||
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||
return True
|
||||
|
||||
def validate_response(self, response: TaskStatusResponse) -> bool:
|
||||
"""
|
||||
Validate the successful task status response from the Runway API
|
||||
matches expected format.
|
||||
"""
|
||||
if not response.output or len(response.output) == 0:
|
||||
raise RunwayApiError(
|
||||
"Runway task succeeded but no image data found in response."
|
||||
)
|
||||
return True
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
# Prepare reference images if provided
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_input_image(reference_image)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload reference image to comfy api.")
|
||||
|
||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||
|
||||
# Create request
|
||||
request = RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
ratio=ratio,
|
||||
referenceImages=reference_images,
|
||||
)
|
||||
|
||||
# Execute initial request
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_IMAGE,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayTextToImageRequest,
|
||||
response_model=RunwayTextToImageResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
# Poll for completion
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
self.validate_response(final_response)
|
||||
|
||||
# Download and return image
|
||||
image_url = get_image_url_from_task_status(final_response)
|
||||
return (download_url_to_image_tensor(image_url),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
|
||||
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
|
||||
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
|
||||
"RunwayTextToImageNode": RunwayTextToImageNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
|
||||
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
|
||||
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
|
||||
"RunwayTextToImageNode": "Runway Text to Image",
|
||||
}
|
||||
@@ -120,12 +120,13 @@ class StabilityStableImageUltraNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
auth_token=None):
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@@ -160,7 +161,7 @@ class StabilityStableImageUltraNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -252,12 +253,13 @@ class StabilityStableImageSD_3_5Node:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
auth_token=None):
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@@ -298,7 +300,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -368,11 +370,12 @@ class StabilityUpscaleConservativeNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
auth_token=None):
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@@ -398,7 +401,7 @@ class StabilityUpscaleConservativeNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -473,11 +476,12 @@ class StabilityUpscaleCreativeNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
auth_token=None):
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@@ -506,7 +510,7 @@ class StabilityUpscaleCreativeNode:
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
@@ -521,7 +525,7 @@ class StabilityUpscaleCreativeNode:
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||
|
||||
@@ -555,11 +559,12 @@ class StabilityUpscaleFastNode:
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor,
|
||||
auth_token=None):
|
||||
**kwargs):
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
@@ -576,7 +581,7 @@ class StabilityUpscaleFastNode:
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_token=auth_token,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
|
||||
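The Stability hunks above all make the same change: the explicit auth_token parameter is dropped from api_call and the node's hidden inputs are forwarded through **kwargs as auth_kwargs instead. A minimal sketch of the new calling pattern, assuming the SynchronousOperation/ApiEndpoint signatures used elsewhere in this diff (PATH_EXAMPLE, ExampleRequest and ExampleResponse are placeholders, not names from the diff):

def api_call(self, prompt: str, **kwargs):
    # kwargs now carries the hidden "auth_token" and "comfy_api_key" inputs.
    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path=PATH_EXAMPLE,
            method=HttpMethod.POST,
            request_model=ExampleRequest,
            response_model=ExampleResponse,
        ),
        request=ExampleRequest(prompt=prompt),
        auth_kwargs=kwargs,  # previously: auth_token=auth_token
    )
    return operation.execute()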
comfy_api_nodes/nodes_tripo.py (new file, 574 lines)
@@ -0,0 +1,574 @@
|
||||
import os
|
||||
from folder_paths import get_output_directory
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.apis import (
|
||||
TripoOrientation,
|
||||
TripoModelVersion,
|
||||
)
|
||||
from comfy_api_nodes.apis.tripo_api import (
|
||||
TripoTaskType,
|
||||
TripoStyle,
|
||||
TripoFileReference,
|
||||
TripoFileEmptyReference,
|
||||
TripoUrlReference,
|
||||
TripoTaskResponse,
|
||||
TripoTaskStatus,
|
||||
TripoTextToModelRequest,
|
||||
TripoImageToModelRequest,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoRefineModelRequest,
|
||||
TripoAnimateRigRequest,
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoConvertModelRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_bytesio,
|
||||
)
|
||||
|
||||
|
||||
def upload_image_to_tripo(image, **kwargs):
|
||||
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
if response.data is not None:
|
||||
for key in ["pbr_model", "model", "base_model"]:
|
||||
if getattr(response.data.output, key, None) is not None:
|
||||
return getattr(response.data.output, key)
|
||||
raise RuntimeError(f"Failed to get model url from response: {response}")
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
kwargs: dict[str, str],
|
||||
response: TripoTaskResponse,
|
||||
) -> tuple[str, str]:
|
||||
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
||||
task_id = response.data.task_id
|
||||
response_poll = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
completed_statuses=[TripoTaskStatus.SUCCESS],
|
||||
failed_statuses=[
|
||||
TripoTaskStatus.FAILED,
|
||||
TripoTaskStatus.CANCELLED,
|
||||
TripoTaskStatus.UNKNOWN,
|
||||
TripoTaskStatus.BANNED,
|
||||
TripoTaskStatus.EXPIRED,
|
||||
],
|
||||
status_extractor=lambda x: x.data.status,
|
||||
auth_kwargs=kwargs,
|
||||
node_id=kwargs["unique_id"],
|
||||
result_url_extractor=get_model_url_from_response,
|
||||
progress_extractor=lambda x: x.data.progress,
|
||||
).execute()
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = download_url_to_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
f.write(bytesio.getvalue())
|
||||
return model_file, task_id
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
class TripoTextToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
||||
"""
|
||||
AVERAGE_DURATION = 80
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": ("STRING", {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": ("STRING", {"multiline": True}),
|
||||
"model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"image_seed": ("INT", {"default": 42}),
|
||||
"model_seed": ("INT", {"default": 42}),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"quad": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
raise RuntimeError("Prompt is required")
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoTextToModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoTextToModelRequest(
|
||||
type=TripoTaskType.TEXT_TO_MODEL,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
model_version=model_version,
|
||||
style=style_enum,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
image_seed=image_seed,
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
auto_size=True,
|
||||
quad=quad
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoImageToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on a single image using Tripo's API.
|
||||
"""
|
||||
AVERAGE_DURATION = 80
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"model_seed": ("INT", {"default": 42}),
|
||||
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"quad": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
raise RuntimeError("Image is required")
|
||||
tripo_file = upload_image_to_tripo(image, **kwargs)
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoImageToModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoImageToModelRequest(
|
||||
type=TripoTaskType.IMAGE_TO_MODEL,
|
||||
file=tripo_file,
|
||||
model_version=model_version,
|
||||
style=style_enum,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
model_seed=model_seed,
|
||||
orientation=orientation,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
auto_size=True,
|
||||
quad=quad
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoMultiviewToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
|
||||
"""
|
||||
AVERAGE_DURATION = 80
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"image_left": ("IMAGE",),
|
||||
"image_back": ("IMAGE",),
|
||||
"image_right": ("IMAGE",),
|
||||
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"model_seed": ("INT", {"default": 42}),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"quad": ("BOOLEAN", {"default": False})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
images = []
|
||||
image_dict = {
|
||||
"image": image,
|
||||
"image_left": image_left,
|
||||
"image_back": image_back,
|
||||
"image_right": image_right
|
||||
}
|
||||
if image_left is None and image_back is None and image_right is None:
|
||||
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
|
||||
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||
image_ = image_dict[image_name]
|
||||
if image_ is not None:
|
||||
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
||||
images.append(tripo_file)
|
||||
else:
|
||||
images.append(TripoFileEmptyReference())
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoMultiviewToModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoMultiviewToModelRequest(
|
||||
type=TripoTaskType.MULTIVIEW_TO_MODEL,
|
||||
files=images,
|
||||
model_version=model_version,
|
||||
orientation=orientation,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit,
|
||||
quad=quad,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoTextureNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_task_id": ("MODEL_TASK_ID",),
|
||||
},
|
||||
"optional": {
|
||||
"texture": ("BOOLEAN", {"default": True}),
|
||||
"pbr": ("BOOLEAN", {"default": True}),
|
||||
"texture_seed": ("INT", {"default": 42}),
|
||||
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 80
|
||||
|
||||
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoTextureModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoTextureModelRequest(
|
||||
original_model_task_id=model_task_id,
|
||||
texture=texture,
|
||||
pbr=pbr,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRefineNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_task_id": ("MODEL_TASK_ID", {
|
||||
"tooltip": "Must be a v1.4 Tripo model"
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
|
||||
|
||||
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||
RETURN_NAMES = ("model_file", "model task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 240
|
||||
|
||||
def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoRefineModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoRefineModelRequest(
|
||||
draft_model_task_id=model_task_id
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRigNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("MODEL_TASK_ID",),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
|
||||
RETURN_NAMES = ("model_file", "rig task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 180
|
||||
|
||||
def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoAnimateRigRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoAnimateRigRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
out_format="glb",
|
||||
spec="tripo"
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoRetargetNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("RIG_TASK_ID",),
|
||||
"animation": ([
|
||||
"preset:idle",
|
||||
"preset:walk",
|
||||
"preset:climb",
|
||||
"preset:jump",
|
||||
"preset:slash",
|
||||
"preset:shoot",
|
||||
"preset:hurt",
|
||||
"preset:fall",
|
||||
"preset:turn",
|
||||
],),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
|
||||
RETURN_NAMES = ("model_file", "retarget task_id")
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoAnimateRetargetRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoAnimateRetargetRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
animation=animation,
|
||||
out_format="glb",
|
||||
bake_animation=True
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
class TripoConversionNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
|
||||
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
|
||||
},
|
||||
"optional": {
|
||||
"quad": ("BOOLEAN", {"default": False}),
|
||||
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
|
||||
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input_types):
|
||||
# Only input_types is taken as an argument here, so the other inputs keep
# their normal validation; this hook just checks the connected task-id type.
|
||||
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
|
||||
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
|
||||
return True
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "generate_mesh"
|
||||
CATEGORY = "api node/3d/Tripo"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
if not original_model_task_id:
|
||||
raise RuntimeError("original_model_task_id is required")
|
||||
response = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
request_model=TripoConvertModelRequest,
|
||||
response_model=TripoTaskResponse,
|
||||
),
|
||||
request=TripoConvertModelRequest(
|
||||
original_model_task_id=original_model_task_id,
|
||||
format=format,
|
||||
quad=quad if quad else None,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
texture_size=texture_size if texture_size != 4096 else None,
|
||||
texture_format=texture_format if texture_format != "JPEG" else None
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripoTextToModelNode": TripoTextToModelNode,
|
||||
"TripoImageToModelNode": TripoImageToModelNode,
|
||||
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
|
||||
"TripoTextureNode": TripoTextureNode,
|
||||
"TripoRefineNode": TripoRefineNode,
|
||||
"TripoRigNode": TripoRigNode,
|
||||
"TripoRetargetNode": TripoRetargetNode,
|
||||
"TripoConversionNode": TripoConversionNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TripoTextToModelNode": "Tripo: Text to Model",
|
||||
"TripoImageToModelNode": "Tripo: Image to Model",
|
||||
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
|
||||
"TripoTextureNode": "Tripo: Texture model",
|
||||
"TripoRefineNode": "Tripo: Refine Draft model",
|
||||
"TripoRigNode": "Tripo: Rig model",
|
||||
"TripoRetargetNode": "Tripo: Retarget rigged model",
|
||||
"TripoConversionNode": "Tripo: Convert model",
|
||||
}
|
||||
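The Tripo nodes above chain through their task-id outputs: the generation nodes return a MODEL_TASK_ID that TripoTextureNode, TripoRefineNode and TripoRigNode consume, TripoRigNode returns a RIG_TASK_ID for TripoRetargetNode, and TripoConversionNode accepts any of the three. A sketch of that wiring in terms of the methods defined above (calling them directly bypasses the ComfyUI graph; the auth values and unique_id are placeholders for illustration only):

auth = {"auth_token": "<token>", "comfy_api_key": "<key>", "unique_id": "node-1"}

model_file, model_task_id = TripoTextToModelNode().generate_mesh(
    prompt="a weathered bronze statue", **auth)
rigged_file, rig_task_id = TripoRigNode().generate_mesh(model_task_id, **auth)
retargeted_file, retarget_task_id = TripoRetargetNode().generate_mesh(
    "preset:walk", rig_task_id, **auth)
TripoConversionNode().generate_mesh(
    retarget_task_id, "FBX", quad=False, face_limit=-1,
    texture_size=4096, texture_format="JPEG", **auth)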
@@ -3,6 +3,7 @@ import logging
|
||||
import base64
|
||||
import requests
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
@@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_base64_string
|
||||
)
|
||||
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
if image is None:
|
||||
return None
|
||||
@@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
else:
|
||||
return None
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return str(video.gcsUri)
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
@@ -114,6 +133,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -133,7 +154,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
auth_token=None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
@@ -179,7 +201,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_token=auth_token
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
@@ -213,8 +235,11 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
request=Veo2GenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_token=auth_token,
|
||||
poll_interval=5.0
|
||||
auth_kwargs=kwargs,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
|
||||
comfy_api_nodes/util/__init__.py (new file, 0 lines)
comfy_api_nodes/util/validation_utils.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import logging
from typing import Optional

import torch
from comfy_api.input.video_types import VideoInput


def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    if len(image.shape) == 4:
        return image.shape[1], image.shape[2]
    elif len(image.shape) == 3:
        return image.shape[0], image.shape[1]
    else:
        raise ValueError("Invalid image tensor shape.")


def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    height, width = get_image_dimensions(image)

    if min_width is not None and width < min_width:
        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Image height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")


def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    width, height = get_image_dimensions(image)
    aspect_ratio = width / height

    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )


def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        logging.error("Error getting dimensions of video: %s", e)
        return

    if min_width is not None and width < min_width:
        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Video height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")


def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    try:
        duration = video.get_duration()
    except Exception as e:
        logging.error("Error getting duration of video: %s", e)
        return

    epsilon = 0.0001
    if min_duration is not None and min_duration - epsilon > duration:
        raise ValueError(
            f"Video duration must be at least {min_duration}s, got {duration}s"
        )
    if max_duration is not None and duration > max_duration + epsilon:
        raise ValueError(
            f"Video duration must be at most {max_duration}s, got {duration}s"
        )
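A short usage sketch for the validation helpers above (the tensor shape and the limits are made up for illustration; the import path assumes the package layout added in this diff):

import torch
from comfy_api_nodes.util.validation_utils import (
    validate_image_dimensions,
    validate_image_aspect_ratio,
)

image = torch.zeros([1, 512, 768, 3])  # one 768x512 image in the usual NHWC layout
validate_image_dimensions(image, min_width=256, max_width=4096,
                          min_height=256, max_height=4096)
validate_image_aspect_ratio(image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)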
comfy_extras/nodes_apg.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch

def project(v0, v1):
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel, v0_orthogonal

class APG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disabled at a setting of 0."}),
                "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip": "Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "sampling/custom_sampling"

    def patch(self, model, eta, norm_threshold, momentum):
        running_avg = 0
        prev_sigma = None

        def pre_cfg_function(args):
            nonlocal running_avg, prev_sigma

            if len(args["conds_out"]) == 1: return args["conds_out"]

            cond = args["conds_out"][0]
            uncond = args["conds_out"][1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            if prev_sigma is not None and sigma > prev_sigma:
                running_avg = 0
            prev_sigma = sigma

            guidance = cond - uncond

            if momentum != 0:
                if not torch.is_tensor(running_avg):
                    running_avg = guidance
                else:
                    running_avg = momentum * running_avg + guidance
                guidance = running_avg

            if norm_threshold > 0:
                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale = torch.minimum(
                    torch.ones_like(guidance_norm),
                    norm_threshold / guidance_norm
                )
                guidance = guidance * scale

            guidance_parallel, guidance_orthogonal = project(guidance, cond)
            modified_guidance = guidance_orthogonal + eta * guidance_parallel

            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale

            return [modified_cond, uncond] + args["conds_out"][2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m,)

NODE_CLASS_MAPPINGS = {
    "APG": APG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "APG": "Adaptive Projected Guidance",
}
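For reference, the pre-CFG patch above implements the following update, written here for exposition (c and u are the conditional and unconditional predictions, s the CFG scale, \hat{c} the normalized conditional prediction; the symbols are chosen here, not taken from the diff):

    g = c - u
    g_\parallel = \langle g, \hat{c} \rangle \, \hat{c}, \qquad g_\perp = g - g_\parallel
    c' = (u + g_\perp + \eta \, g_\parallel) + (c - u) / s

with g optionally replaced by the running average m \leftarrow momentum \cdot m + g, and rescaled so that \lVert g \rVert_2 \le norm_threshold, before the projection.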
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import av
|
||||
import torchaudio
|
||||
import torch
|
||||
import comfy.model_management
|
||||
@@ -7,7 +8,6 @@ import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
@@ -90,60 +90,118 @@ class VAEDecodeAudio:
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
|
||||
def create_vorbis_comment_block(comment_dict, last_block):
|
||||
vendor_string = b'ComfyUI'
|
||||
vendor_length = len(vendor_string)
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
|
||||
comments = []
|
||||
for key, value in comment_dict.items():
|
||||
comment = f"{key}={value}".encode('utf-8')
|
||||
comments.append(struct.pack('<I', len(comment)) + comment)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
|
||||
user_comment_list_length = len(comments)
|
||||
user_comments = b''.join(comments)
|
||||
# Prepare metadata dictionary
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
|
||||
if last_block:
|
||||
id = b'\x84'
|
||||
else:
|
||||
id = b'\x04'
|
||||
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
|
||||
# Opus supported sample rates
|
||||
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
return comment_block
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
||||
if len(comment_dict) == 0:
|
||||
return flac_io
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
flac_io.seek(4)
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
blocks = []
|
||||
last_block = False
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
while not last_block:
|
||||
header = flac_io.read(4)
|
||||
last_block = (header[0] & 0x80) != 0
|
||||
block_type = header[0] & 0x7F
|
||||
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
|
||||
block_data = flac_io.read(block_length)
|
||||
# Create in-memory WAV buffer
|
||||
wav_buffer = io.BytesIO()
|
||||
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
|
||||
wav_buffer.seek(0) # Rewind for reading
|
||||
|
||||
if block_type == 4 or block_type == 1:
|
||||
pass
|
||||
else:
|
||||
header = bytes([(header[0] & (~0x80))]) + header[1:]
|
||||
blocks.append(header + block_data)
|
||||
# Use PyAV to convert and add metadata
|
||||
input_container = av.open(wav_buffer)
|
||||
|
||||
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
|
||||
new_flac_io = io.BytesIO()
|
||||
new_flac_io.write(b'fLaC')
|
||||
for block in blocks:
|
||||
new_flac_io.write(block)
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
new_flac_io.write(flac_io.read())
|
||||
return new_flac_io
|
||||
# Set up the output stream with appropriate properties
|
||||
input_container.streams.audio[0]
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO: I would really love to support V3 and V5, but there doesn't seem to be a way to set the qscale level; the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
|
||||
# Copy frames from input to output
|
||||
for frame in input_container.decode(audio=0):
|
||||
frame.pts = None # Let PyAV handle timestamps
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@@ -153,50 +211,70 @@ class SaveAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_audio"
|
||||
FUNCTION = "save_flac"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
buff = io.BytesIO()
|
||||
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_mp3"
|
||||
|
||||
buff = insert_or_replace_vorbis_comment(buff, metadata)
|
||||
OUTPUT_NODE = True
|
||||
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as f:
|
||||
f.write(buff.getbuffer())
|
||||
CATEGORY = "audio"
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_opus"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
@@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
"VAEDecodeAudio": VAEDecodeAudio,
|
||||
"SaveAudio": SaveAudio,
|
||||
"SaveAudioMP3": SaveAudioMP3,
|
||||
"SaveAudioOpus": SaveAudioOpus,
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentAudio": "Empty Latent Audio",
|
||||
"VAEEncodeAudio": "VAE Encode Audio",
|
||||
"VAEDecodeAudio": "VAE Decode Audio",
|
||||
"PreviewAudio": "Preview Audio",
|
||||
"LoadAudio": "Load Audio",
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
}
|
||||
|
||||
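The save_audio hunk above interleaves removed and added lines, which makes the new Opus handling hard to follow; it boils down to coercing the sample rate to one libopus accepts before resampling with torchaudio.functional.resample. A standalone sketch of that selection logic (the function name is chosen here, not taken from the diff):

OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

def coerce_opus_rate(sample_rate: int) -> int:
    # Cap at 48 kHz, keep already-supported rates, otherwise take the next
    # highest supported rate, falling back to 48 kHz.
    if sample_rate > 48000:
        return 48000
    if sample_rate in OPUS_RATES:
        return sample_rate
    for rate in sorted(OPUS_RATES):
        if rate > sample_rate:
            return rate
    return 48000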
comfy_extras/nodes_camera_trajectory.py (new file, 218 lines)
@@ -0,0 +1,218 @@
|
||||
import nodes
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
CAMERA_DICT = {
|
||||
"base_T_norm": 1.5,
|
||||
"base_angle": np.pi/3,
|
||||
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
|
||||
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
|
||||
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
|
||||
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
|
||||
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
|
||||
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
|
||||
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
|
||||
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
|
||||
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
|
||||
}
|
||||
|
||||
|
||||
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
||||
|
||||
def get_relative_pose(cam_params):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
||||
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
||||
cam_to_origin = 0
|
||||
target_cam_c2w = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, -cam_to_origin],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
||||
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
||||
ret_poses = np.array(ret_poses, dtype=np.float32)
|
||||
return ret_poses
|
||||
|
||||
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
||||
|
||||
sample_wh_ratio = width / height
|
||||
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
||||
|
||||
if pose_wh_ratio > sample_wh_ratio:
|
||||
resized_ori_w = height * pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fx = resized_ori_w * cam_param.fx / width
|
||||
else:
|
||||
resized_ori_h = width / pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fy = resized_ori_h * cam_param.fy / height
|
||||
|
||||
intrinsic = np.asarray([[cam_param.fx * width,
|
||||
cam_param.fy * height,
|
||||
cam_param.cx * width,
|
||||
cam_param.cy * height]
|
||||
for cam_param in cam_params], dtype=np.float32)
|
||||
|
||||
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
||||
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
||||
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
||||
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
return plucker_embedding
|
||||
|
||||
class Camera(object):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
def __init__(self, entry):
|
||||
fx, fy, cx, cy = entry[1:5]
|
||||
self.fx = fx
|
||||
self.fy = fy
|
||||
self.cx = cx
|
||||
self.cy = cy
|
||||
c2w_mat = np.array(entry[7:]).reshape(4, 4)
|
||||
self.c2w_mat = c2w_mat
|
||||
self.w2c_mat = np.linalg.inv(c2w_mat)
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
# c2w: B, V, 4, 4
|
||||
# K: B, V, 4
|
||||
|
||||
B = K.shape[0]
|
||||
|
||||
j, i = torch.meshgrid(
|
||||
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
||||
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
||||
indexing='ij'
|
||||
)
|
||||
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
|
||||
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
||||
|
||||
zs = torch.ones_like(i) # [B, HxW]
|
||||
xs = (i - cx) / fx * zs
|
||||
ys = (j - cy) / fy * zs
|
||||
zs = zs.expand_as(ys)
|
||||
|
||||
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
||||
|
||||
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ directions
|
||||
rays_dxo = torch.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
def get_camera_motion(angle, T, speed, n=81):
|
||||
def compute_R_form_rad_angle(angles):
|
||||
theta_x, theta_y, theta_z = angles
|
||||
Rx = np.array([[1, 0, 0],
|
||||
[0, np.cos(theta_x), -np.sin(theta_x)],
|
||||
[0, np.sin(theta_x), np.cos(theta_x)]])
|
||||
|
||||
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
|
||||
[0, 1, 0],
|
||||
[-np.sin(theta_y), 0, np.cos(theta_y)]])
|
||||
|
||||
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
|
||||
[np.sin(theta_z), np.cos(theta_z), 0],
|
||||
[0, 0, 1]])
|
||||
|
||||
R = np.dot(Rz, np.dot(Ry, Rx))
|
||||
return R
|
||||
RT = []
|
||||
for i in range(n):
|
||||
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
|
||||
R = compute_R_form_rad_angle(_angle)
|
||||
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
|
||||
_RT = np.concatenate([R,_T], axis=1)
|
||||
RT.append(_RT)
|
||||
RT = np.stack(RT)
|
||||
return RT
|
||||
|
||||
class WanCameraEmbedding:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
|
||||
},
|
||||
"optional":{
|
||||
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
|
||||
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
||||
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
|
||||
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
||||
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
|
||||
RETURN_NAMES = ("camera_embedding","width","height","length")
|
||||
FUNCTION = "run"
|
||||
CATEGORY = "camera"
|
||||
|
||||
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
|
||||
"""
|
||||
Use camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
|
||||
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
||||
"""
|
||||
motion_list = [camera_pose]
|
||||
speed = speed
|
||||
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
|
||||
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
|
||||
RT = get_camera_motion(angle, T, speed, length)
|
||||
|
||||
trajs=[]
|
||||
for cp in RT.tolist():
|
||||
traj=[fx,fy,cx,cy,0,0]
|
||||
traj.extend(cp[0])
|
||||
traj.extend(cp[1])
|
||||
traj.extend(cp[2])
|
||||
traj.extend([0,0,0,1])
|
||||
trajs.append(traj)
|
||||
|
||||
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
|
||||
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
|
||||
control_camera_video = process_pose_params(cam_params, width=width, height=height)
|
||||
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
|
||||
|
||||
control_camera_video = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
|
||||
# Reshape, transpose, and view into desired shape
|
||||
b, f, c, h, w = control_camera_video.shape
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
|
||||
return (control_camera_video, width, height, length)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanCameraEmbedding": WanCameraEmbedding,
|
||||
}
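# Illustrative sketch (not part of the diff): expected tensor shapes when building Plücker
# embeddings with the helpers above. The resolution, frame count and identity camera poses
# below are arbitrary example values, not values used by the node.
import torch

def plucker_shape_demo():
    B, V, H, W = 1, 4, 8, 8
    c2w = torch.eye(4).repeat(B, V, 1, 1)                # [B, V, 4, 4] camera-to-world poses
    K = torch.tensor([[float(W), float(H), W / 2, H / 2]] * V)[None]  # [B, V, 4] = fx, fy, cx, cy in pixels
    plucker = ray_condition(K, c2w, H, W, device="cpu")  # [B, V, H, W, 6]
    return plucker.shape

# plucker_shape_demo() -> torch.Size([1, 4, 8, 8, 6])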
|
||||
@@ -31,6 +31,7 @@ class T5TokenizerOptions:
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "_for_testing/conditioning"
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "set_options"
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
|
||||
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
|
||||
|
||||
if guidance_type == "v1 (concat)":
|
||||
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
||||
else:
|
||||
elif guidance_type == "v2 (replace)":
|
||||
cond = {'guiding_frame_index': 0}
|
||||
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
||||
out_latent["noise_mask"] = mask
|
||||
elif guidance_type == "custom":
|
||||
cond = {"ref_latent": concat_latent_image}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||
|
||||
|
||||
@@ -10,6 +10,11 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@@ -71,6 +76,24 @@ class ImageFromBatch:
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
|
||||
|
||||
class ImageAddNoise:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def repeat(self, image, seed, strength):
|
||||
generator = torch.manual_seed(seed)
|
||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||
return (s,)
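# Illustrative sketch (not part of the diff): calling the ImageAddNoise node directly on a
# ComfyUI-style image tensor of shape [batch, height, width, channels] with values in 0..1.
import torch

example_image = torch.rand(1, 64, 64, 3)
noised, = ImageAddNoise().repeat(example_image, seed=0, strength=0.25)
# noised has the same shape as example_image and is clipped back into the 0..1 range.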
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@@ -190,10 +213,291 @@ class SaveAnimatedPNG:
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
max_batch = max(image1.shape[0], image2.shape[0])
|
||||
if image1.shape[0] < max_batch:
|
||||
image1 = torch.cat(
|
||||
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
||||
)
|
||||
if image2.shape[0] < max_batch:
|
||||
image2 = torch.cat(
|
||||
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
||||
)
|
||||
|
||||
# Match image sizes if requested
|
||||
if match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
aspect_ratio = w2 / h2
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
target_h, target_w = h1, int(h1 * aspect_ratio)
|
||||
else: # up, down
|
||||
target_w, target_h = w1, int(w1 / aspect_ratio)
|
||||
|
||||
image2 = comfy.utils.common_upscale(
|
||||
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
||||
).movedim(1, -1)
|
||||
|
||||
# When not matching sizes, pad to align non-concat dimensions
|
||||
if not match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
# For horizontal concat, pad heights to match
|
||||
if h1 != h2:
|
||||
target_h = max(h1, h2)
|
||||
if h1 < target_h:
|
||||
pad_h = target_h - h1
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
if h2 < target_h:
|
||||
pad_h = target_h - h2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
else: # up, down
|
||||
# For vertical concat, pad widths to match
|
||||
if w1 != w2:
|
||||
target_w = max(w1, w2)
|
||||
if w1 < target_w:
|
||||
pad_w = target_w - w1
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
if w2 < target_w:
|
||||
pad_w = target_w - w2
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
|
||||
# Ensure same number of channels
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
max_channels = max(image1.shape[-1], image2.shape[-1])
|
||||
if image1.shape[-1] < max_channels:
|
||||
image1 = torch.cat(
|
||||
[
|
||||
image1,
|
||||
torch.ones(
|
||||
*image1.shape[:-1],
|
||||
max_channels - image1.shape[-1],
|
||||
device=image1.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if image2.shape[-1] < max_channels:
|
||||
image2 = torch.cat(
|
||||
[
|
||||
image2,
|
||||
torch.ones(
|
||||
*image2.shape[:-1],
|
||||
max_channels - image2.shape[-1],
|
||||
device=image2.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Add spacing if specified
|
||||
if spacing_width > 0:
|
||||
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
||||
|
||||
color_map = {
|
||||
"white": 1.0,
|
||||
"black": 0.0,
|
||||
"red": (1.0, 0.0, 0.0),
|
||||
"green": (0.0, 1.0, 0.0),
|
||||
"blue": (0.0, 0.0, 1.0),
|
||||
}
|
||||
color_val = color_map[spacing_color]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
max(image1.shape[1], image2.shape[1]),
|
||||
spacing_width,
|
||||
image1.shape[-1],
|
||||
)
|
||||
else:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
spacing_width,
|
||||
max(image1.shape[2], image2.shape[2]),
|
||||
image1.shape[-1],
|
||||
)
|
||||
|
||||
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
||||
if isinstance(color_val, tuple):
|
||||
for i, c in enumerate(color_val):
|
||||
if i < spacing.shape[-1]:
|
||||
spacing[..., i] = c
|
||||
if spacing.shape[-1] == 4: # Add alpha
|
||||
spacing[..., 3] = 1.0
|
||||
else:
|
||||
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
||||
if spacing.shape[-1] == 4:
|
||||
spacing[..., 3] = 1.0
|
||||
|
||||
# Concatenate images
|
||||
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
||||
if spacing_width > 0:
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
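# Illustrative sketch (not part of the diff): stitching two differently sized example images
# side by side with size matching and an 8 px white spacer.
import torch

image_a = torch.rand(1, 64, 48, 3)  # [B, H, W, C]
image_b = torch.rand(1, 32, 32, 3)
stitched, = ImageStitch().stitch(image_a, "right", True, spacing_width=8,
                                 spacing_color="white", image2=image_b)
# stitched.shape == (1, 64, 120, 3): image_b is resized to height 64 (width 64),
# then widths 48 + 8 + 64 are concatenated along dim=2.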
|
||||
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "image/save" # Changed
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("SVG",), # Changed
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
|
||||
# Read SVG content
|
||||
svg_bytes.seek(0)
|
||||
svg_content = svg_bytes.read().decode('utf-8')
|
||||
|
||||
# Inject metadata if available
|
||||
if metadata_json:
|
||||
# Create metadata element with CDATA section
|
||||
metadata_element = f""" <metadata>
|
||||
<![CDATA[
|
||||
{metadata_json}
|
||||
]]>
|
||||
</metadata>
|
||||
"""
|
||||
# Insert metadata after opening svg tag using regex with a replacement function
|
||||
def replacement(match):
|
||||
# match.group(1) contains the captured <svg> tag
|
||||
return match.group(1) + '\n' + metadata_element
|
||||
|
||||
# Apply the substitution
|
||||
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
|
||||
|
||||
# Write the modified SVG to file
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"ImageAddNoise": ImageAddNoise,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,10 @@ import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
|
||||
@@ -12,7 +16,7 @@ class Load3D():
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
@@ -21,8 +25,8 @@ class Load3D():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -41,7 +45,14 @@ class Load3D():
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@@ -59,8 +70,8 @@ class Load3DAnimation():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@@ -77,7 +88,14 @@ class Load3DAnimation():
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
|
||||
360
comfy_extras/nodes_string.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import re
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
|
||||
class StringConcatenate():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, delimiter, **kwargs):
|
||||
return delimiter.join((string_a, string_b)),
|
||||
|
||||
class StringSubstring():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"start": (IO.INT, {}),
|
||||
"end": (IO.INT, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, start, end, **kwargs):
|
||||
return string[start:end],
|
||||
|
||||
class StringLength():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("length",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, **kwargs):
|
||||
length = len(string)
|
||||
|
||||
return length,
|
||||
|
||||
class CaseConverter():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
if mode == "UPPERCASE":
|
||||
result = string.upper()
|
||||
elif mode == "lowercase":
|
||||
result = string.lower()
|
||||
elif mode == "Capitalize":
|
||||
result = string.capitalize()
|
||||
elif mode == "Title Case":
|
||||
result = string.title()
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class StringTrim():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
if mode == "Both":
|
||||
result = string.strip()
|
||||
elif mode == "Left":
|
||||
result = string.lstrip()
|
||||
elif mode == "Right":
|
||||
result = string.rstrip()
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
|
||||
class StringReplace():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"find": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, find, replace, **kwargs):
|
||||
result = string.replace(find, replace)
|
||||
return result,
|
||||
|
||||
|
||||
class StringContains():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"substring": (IO.STRING, {"multiline": True}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("contains",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, substring, case_sensitive, **kwargs):
|
||||
if case_sensitive:
|
||||
contains = substring in string
|
||||
else:
|
||||
contains = substring.lower() in string.lower()
|
||||
|
||||
return contains,
|
||||
|
||||
|
||||
class StringCompare():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
|
||||
if case_sensitive:
|
||||
a = string_a
|
||||
b = string_b
|
||||
else:
|
||||
a = string_a.lower()
|
||||
b = string_b.lower()
|
||||
|
||||
if mode == "Equal":
|
||||
return a == b,
|
||||
elif mode == "Starts With":
|
||||
return a.startswith(b),
|
||||
elif mode == "Ends With":
|
||||
return a.endswith(b),
|
||||
|
||||
class RegexMatch():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("matches",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
|
||||
try:
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
result = match is not None
|
||||
|
||||
except re.error:
|
||||
result = False
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class RegexExtract():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False}),
|
||||
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
|
||||
join_delimiter = "\n"
|
||||
|
||||
flags = 0
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
|
||||
try:
|
||||
if mode == "First Match":
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
if match:
|
||||
result = match.group(0)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "All Matches":
|
||||
matches = re.findall(regex_pattern, string, flags)
|
||||
if matches:
|
||||
if isinstance(matches[0], tuple):
|
||||
result = join_delimiter.join([m[0] for m in matches])
|
||||
else:
|
||||
result = join_delimiter.join(matches)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "First Group":
|
||||
match = re.search(regex_pattern, string, flags)
|
||||
if match and len(match.groups()) >= group_index:
|
||||
result = match.group(group_index)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
elif mode == "All Groups":
|
||||
matches = re.finditer(regex_pattern, string, flags)
|
||||
results = []
|
||||
for match in matches:
|
||||
if match.groups() and len(match.groups()) >= group_index:
|
||||
results.append(match.group(group_index))
|
||||
result = join_delimiter.join(results)
|
||||
else:
|
||||
result = ""
|
||||
|
||||
except re.error:
|
||||
result = ""
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class RegexReplace():
|
||||
DESCRIPTION = "Find and replace text using regex patterns."
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||
return result,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
"StringLength": StringLength,
|
||||
"CaseConverter": CaseConverter,
|
||||
"StringTrim": StringTrim,
|
||||
"StringReplace": StringReplace,
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract,
|
||||
"RegexReplace": RegexReplace,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringConcatenate": "Concatenate",
|
||||
"StringSubstring": "Substring",
|
||||
"StringLength": "Length",
|
||||
"CaseConverter": "Case Converter",
|
||||
"StringTrim": "Trim",
|
||||
"StringReplace": "Replace",
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract",
|
||||
"RegexReplace": "Regex Replace",
|
||||
}
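# Illustrative sketch (not part of the diff): exercising a few of the string utility nodes above.
example_text = "model_v1.safetensors\nmodel_v2.safetensors"

matches, = RegexMatch().execute(example_text, r"v\d+",
                                case_insensitive=True, multiline=False, dotall=False)
versions, = RegexExtract().execute(example_text, r"model_(v\d+)", "All Groups",
                                   case_insensitive=True, multiline=False, dotall=False, group_index=1)
renamed, = RegexReplace().execute(example_text, r"\.safetensors$", ".ckpt", multiline=True)
# matches  -> True
# versions -> "v1\nv2" (group matches joined with a newline)
# renamed  -> both extensions replaced, since MULTILINE lets $ match at each line end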
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
|
||||
|
||||
class TorchCompileModel:
|
||||
@classmethod
|
||||
@@ -14,7 +15,7 @@ class TorchCompileModel:
|
||||
|
||||
def patch(self, model, backend):
|
||||
m = model.clone()
|
||||
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
|
||||
set_torch_compile_wrapper(model=m, backend=backend)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -268,8 +268,9 @@ class WanVaceToVideo:
|
||||
trim_latent = reference_image.shape[2]
|
||||
|
||||
mask = mask.unsqueeze(0)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
|
||||
|
||||
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
@@ -297,6 +298,90 @@ class TrimVideoLatent:
|
||||
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||
return (samples_out,)
|
||||
|
||||
class WanCameraImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
||||
|
||||
if camera_conditions is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
|
||||
negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
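# Illustrative sketch (not part of the diff) of the latent sizing used above: Wan latents hold
# one temporal latent per 4 video frames (plus the first frame) at 1/8 spatial resolution.
def wan_latent_shape(batch_size, length, height, width):
    return [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]

# wan_latent_shape(1, 81, 480, 832) -> [1, 16, 21, 60, 104]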
|
||||
|
||||
class WanPhantomSubjectToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"images": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
cond2 = negative
|
||||
if images is not None:
|
||||
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
latent_images = []
|
||||
for i in images:
|
||||
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
|
||||
concat_latent_image = torch.cat(latent_images, dim=2)
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
|
||||
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, cond2, negative, out_latent)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
@@ -305,4 +390,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||
"WanVaceToVideo": WanVaceToVideo,
|
||||
"TrimVideoLatent": TrimVideoLatent,
|
||||
"WanCameraImageToVideo": WanCameraImageToVideo,
|
||||
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.33"
|
||||
__version__ = "0.3.39"
|
||||
|
||||
11
execution.py
@@ -146,6 +146,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
@@ -907,7 +909,6 @@ class PromptQueue:
|
||||
self.currently_running = {}
|
||||
self.history = {}
|
||||
self.flags = {}
|
||||
server.prompt_queue = self
|
||||
|
||||
def put(self, item):
|
||||
with self.mutex:
|
||||
@@ -952,6 +953,7 @@ class PromptQueue:
|
||||
self.history[prompt[1]].update(history_result)
|
||||
self.server.queue_updated()
|
||||
|
||||
# Note: slow
|
||||
def get_current_queue(self):
|
||||
with self.mutex:
|
||||
out = []
|
||||
@@ -959,6 +961,13 @@ class PromptQueue:
|
||||
out += [x]
|
||||
return (out, copy.deepcopy(self.queue))
|
||||
|
||||
# read-safe as long as queue items are immutable
|
||||
def get_current_queue_volatile(self):
|
||||
with self.mutex:
|
||||
running = [x for x in self.currently_running.values()]
|
||||
queued = copy.copy(self.queue)
|
||||
return (running, queued)
|
||||
|
||||
def get_tasks_remaining(self):
|
||||
with self.mutex:
|
||||
return len(self.queue) + len(self.currently_running)
|
||||
|
||||
28
fix_torch.py
@@ -1,28 +0,0 @@
|
||||
import importlib.util
|
||||
import shutil
|
||||
import os
|
||||
import ctypes
|
||||
import logging
|
||||
|
||||
|
||||
def fix_pytorch_libomp():
|
||||
"""
|
||||
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
|
||||
"""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
lib_folder = os.path.join(folder, "lib")
|
||||
test_file = os.path.join(lib_folder, "fbgemm.dll")
|
||||
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
|
||||
if os.path.exists(dest):
|
||||
break
|
||||
|
||||
with open(test_file, "rb") as f:
|
||||
contents = f.read()
|
||||
if b"libomp140.x86_64.dll" not in contents:
|
||||
break
|
||||
try:
|
||||
ctypes.cdll.LoadLibrary(test_file)
|
||||
except FileNotFoundError:
|
||||
logging.warning("Detected pytorch version with libomp issue, patching.")
|
||||
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
|
||||
@@ -275,7 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
||||
|
||||
|
||||
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in folder_names_and_paths:
|
||||
@@ -288,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
return full_path
|
||||
elif os.path.islink(full_path):
|
||||
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
|
||||
elif allow_missing:
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
@@ -299,6 +301,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
return full_path
|
||||
|
||||
|
||||
def get_relative_path(full_path: str) -> tuple[str, str] | None:
|
||||
"""Convert a full path back to a type-relative path.
|
||||
|
||||
Args:
|
||||
full_path: The full path to the file
|
||||
|
||||
Returns:
|
||||
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
full_path = os.path.normpath(full_path)
|
||||
|
||||
for model_type, (paths, _) in folder_names_and_paths.items():
|
||||
for base_path in paths:
|
||||
base_path = os.path.normpath(base_path)
|
||||
if full_path.startswith(base_path):
|
||||
relative_path = os.path.relpath(full_path, base_path)
|
||||
return model_type, relative_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
|
||||
19
main.py
@@ -125,13 +125,6 @@ if __name__ == "__main__":
|
||||
|
||||
import cuda_malloc
|
||||
|
||||
if args.windows_standalone_build:
|
||||
try:
|
||||
from fix_torch import fix_pytorch_libomp
|
||||
fix_pytorch_libomp()
|
||||
except:
|
||||
pass
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
@@ -154,7 +147,6 @@ def cuda_malloc_warning():
|
||||
if cuda_malloc_warning:
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -244,6 +236,13 @@ def cleanup_temp():
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
"""
|
||||
@@ -267,18 +266,18 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
q = execution.PromptQueue(prompt_server)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
@@ -5,12 +5,18 @@ from comfy.cli_args import args
|
||||
|
||||
from PIL import ImageFile, UnidentifiedImageError
|
||||
|
||||
def conditioning_set_values(conditioning, values={}):
|
||||
def conditioning_set_values(conditioning, values={}, append=False):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
n[1][k] = values[k]
|
||||
val = values[k]
|
||||
if append:
|
||||
old_val = n[1].get(k, None)
|
||||
if old_val is not None:
|
||||
val = old_val + val
|
||||
|
||||
n[1][k] = val
|
||||
c.append(n)
|
||||
|
||||
return c
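# Illustrative sketch (not part of the diff) of the new append behaviour: when append=True and
# the key already holds a list, the incoming list is concatenated instead of replacing it.
example_cond = [["tokens", {"vace_strength": [0.5]}]]
example_cond = conditioning_set_values(example_cond, {"vace_strength": [1.0]}, append=True)
# example_cond[0][1]["vace_strength"] -> [0.5, 1.0]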
|
||||
|
||||
21
nodes.py
@@ -1103,16 +1103,7 @@ class unCLIPConditioning:
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
for t in conditioning:
|
||||
o = t[1].copy()
|
||||
x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
|
||||
if "unclip_conditioning" in o:
|
||||
o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
|
||||
else:
|
||||
o["unclip_conditioning"] = [x]
|
||||
n = [t[0], o]
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
|
||||
return (c, )
|
||||
|
||||
class GLIGENLoader:
|
||||
@@ -1940,7 +1931,7 @@ class ImagePadForOutpaint:
|
||||
|
||||
mask[top:top + d2, left:left + d3] = t
|
||||
|
||||
return (new_image, mask)
|
||||
return (new_image, mask.unsqueeze(0))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -2070,6 +2061,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||
"ImageBatch": "Batch Images",
|
||||
"ImageCrop": "Image Crop",
|
||||
"ImageStitch": "Image Stitch",
|
||||
"ImageBlend": "Image Blend",
|
||||
"ImageBlur": "Image Blur",
|
||||
"ImageQuantize": "Image Quantize",
|
||||
@@ -2261,8 +2253,11 @@ def init_builtin_extra_nodes():
|
||||
"nodes_optimalsteps.py",
|
||||
"nodes_hidream.py",
|
||||
"nodes_fresca.py",
|
||||
"nodes_apg.py",
|
||||
"nodes_preview_any.py",
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2287,6 +2282,10 @@ def init_builtin_api_nodes():
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
]
|
||||
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.33"
|
||||
version = "0.3.39"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
comfyui-frontend-package==1.18.9
|
||||
comfyui-workflow-templates==0.1.11
|
||||
comfyui-frontend-package==1.21.3
|
||||
comfyui-workflow-templates==0.1.25
|
||||
comfyui-embedded-docs==0.2.0
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@@ -17,6 +18,9 @@ Pillow
|
||||
scipy
|
||||
tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
@@ -101,6 +101,14 @@ prompt_text = """
|
||||
|
||||
def queue_prompt(prompt):
|
||||
p = {"prompt": prompt}
|
||||
|
||||
# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data` field of the payload.
|
||||
# p["extra_data"] = {
|
||||
# "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
|
||||
# }
|
||||
# See: https://docs.comfy.org/tutorials/api-nodes/overview
|
||||
# Generate a key here: https://platform.comfy.org/login
|
||||
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = request.Request("http://127.0.0.1:8188/prompt", data=data)
|
||||
request.urlopen(req)
|
||||
|
||||
30
server.py
@@ -29,15 +29,17 @@ import comfy.model_management
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
@@ -158,7 +160,7 @@ class PromptServer():
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
self.loop = loop
|
||||
self.messages = asyncio.Queue()
|
||||
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||
@@ -225,7 +227,7 @@ class PromptServer():
|
||||
return response
|
||||
|
||||
@routes.get("/embeddings")
|
||||
def get_embeddings(self):
|
||||
def get_embeddings(request):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@@ -281,7 +283,6 @@ class PromptServer():
|
||||
a.update(f.read())
|
||||
b.update(image.file.read())
|
||||
image.file.seek(0)
|
||||
f.close()
|
||||
return a.hexdigest() == b.hexdigest()
|
||||
return False
|
||||
|
||||
@@ -620,7 +621,7 @@ class PromptServer():
|
||||
@routes.get("/queue")
|
||||
async def get_queue(request):
|
||||
queue_info = {}
|
||||
current_queue = self.prompt_queue.get_current_queue()
|
||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||
queue_info['queue_running'] = current_queue[0]
|
||||
queue_info['queue_pending'] = current_queue[1]
|
||||
return web.json_response(queue_info)
|
||||
@@ -745,6 +746,13 @@ class PromptServer():
|
||||
web.static('/templates', workflow_templates_path)
|
||||
])
|
||||
|
||||
# Serve embedded documentation from the package
|
||||
embedded_docs_path = FrontendManager.embedded_docs_path()
|
||||
if embedded_docs_path:
|
||||
self.app.add_routes([
|
||||
web.static('/docs', embedded_docs_path)
|
||||
])
|
||||
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
])
|
||||
@@ -878,3 +886,15 @@ class PromptServer():
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
return json_data
|
||||
|
||||
def send_progress_text(
|
||||
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
|
||||
):
|
||||
if isinstance(text, str):
|
||||
text = text.encode("utf-8")
|
||||
node_id_bytes = str(node_id).encode("utf-8")
|
||||
|
||||
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
||||
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||||
|
||||
self.send_sync(BinaryEventTypes.TEXT, message, sid)
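# Illustrative sketch (not part of the diff): how a websocket client could decode the binary
# TEXT event built above (4-byte big-endian node id length, node id bytes, then UTF-8 text).
import struct

def decode_text_event(message: bytes):
    (node_id_len,) = struct.unpack(">I", message[:4])
    node_id = message[4:4 + node_id_len].decode("utf-8")
    text = message[4 + node_id_len:].decode("utf-8")
    return node_id, text

# decode_text_event(struct.pack(">I", 2) + b"12" + b"50%") -> ("12", "50%")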
|
||||
|
||||
253
tests-unit/app_test/model_processor_test.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.model_processor import ModelProcessor
|
||||
from app.database.models import Model, Base
|
||||
import os
|
||||
|
||||
# Test data constants
|
||||
TEST_MODEL_TYPE = "checkpoints"
|
||||
TEST_URL = "http://example.com/model.safetensors"
|
||||
TEST_FILE_NAME = "model.safetensors"
|
||||
TEST_EXPECTED_HASH = "abc123"
|
||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
||||
|
||||
|
||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
||||
"""Helper to create a test model in the database."""
|
||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
return model
|
||||
|
||||
|
||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
||||
"""Helper to setup hash calculation mocks."""
|
||||
mock_hash = MagicMock()
|
||||
mock_hash.hexdigest.return_value = hash_value
|
||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
||||
|
||||
|
||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
||||
"""Helper to verify model exists in database with correct attributes."""
|
||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
||||
assert db_model is not None
|
||||
if expected_hash:
|
||||
assert db_model.hash == expected_hash
|
||||
if expected_type:
|
||||
assert db_model.type == expected_type
|
||||
return db_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
# Configure in-memory database
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
yield engine
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(db_engine):
|
||||
Session = sessionmaker(bind=db_engine)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_relative_path():
|
||||
with patch("app.model_processor.get_relative_path") as mock:
|
||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_full_path():
|
||||
with patch("app.model_processor.get_full_path") as mock:
|
||||
mock.return_value = TEST_DESTINATION_PATH
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
    with patch("app.model_processor.create_session", return_value=db_session):
        with patch("app.model_processor.can_create_session", return_value=True):
            processor = ModelProcessor()

            # Setup test state
            processor.removed_files = []
            processor.downloaded_files = []
            processor.file_exists = {}

            def mock_download_file(url, destination_path, hasher):
                processor.downloaded_files.append((url, destination_path))
                processor.file_exists[destination_path] = True
                # Simulate writing some data to the file
                test_data = b"test data"
                hasher.update(test_data)

            def mock_remove_file(file_path):
                processor.removed_files.append(file_path)
                if file_path in processor.file_exists:
                    del processor.file_exists[file_path]

            # Setup common patches
            file_exists_patch = patch.object(
                processor,
                "_file_exists",
                side_effect=lambda path: processor.file_exists.get(path, False),
            )
            file_size_patch = patch.object(
                processor,
                "_get_file_size",
                side_effect=lambda path: (
                    1000 if processor.file_exists.get(path, False) else 0
                ),
            )
            download_file_patch = patch.object(
                processor, "_download_file", side_effect=mock_download_file
            )
            remove_file_patch = patch.object(
                processor, "_remove_file", side_effect=mock_remove_file
            )

            with (
                file_exists_patch,
                file_size_patch,
                download_file_patch,
                remove_file_patch,
            ):
                yield processor


def test_ensure_downloaded_invalid_extension(model_processor):
    # Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
    with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
        model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")


def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
    # Ensure that a file with the same hash but from a different source is not downloaded again
    SOURCE_URL = "https://example.com/other.sft"
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    result = model_processor.ensure_downloaded(
        TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
    )

    assert result == TEST_DESTINATION_PATH
    model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
    assert model.source_url == SOURCE_URL  # Ensure the source URL is not overwritten


def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
    # Ensure that a file with a different hash raises an error
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
        model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )


def test_ensure_downloaded_new_file(model_processor, db_session):
    # Ensure that a new file is downloaded
    model_processor.file_exists[TEST_DESTINATION_PATH] = False

    with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
        result = model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )

    assert result == TEST_DESTINATION_PATH
    assert len(model_processor.downloaded_files) == 1
    assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
    assert model_processor.file_exists[TEST_DESTINATION_PATH]
    verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)


def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
    # Ensure that a download that results in a different hash raises an error
    model_processor.file_exists[TEST_DESTINATION_PATH] = False

    with setup_mock_hash_calculation(model_processor, "different_hash"):
        with pytest.raises(
            ValueError,
            match="Downloaded file hash .* does not match expected hash .*",
        ):
            model_processor.ensure_downloaded(
                TEST_MODEL_TYPE,
                TEST_URL,
                TEST_FILE_NAME,
                TEST_EXPECTED_HASH,
            )

    assert len(model_processor.removed_files) == 1
    assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
    assert TEST_DESTINATION_PATH not in model_processor.file_exists
    assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None


def test_process_file_without_hash(model_processor, db_session):
    # Test processing a file without a provided hash
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
        result = model_processor.process_file(TEST_DESTINATION_PATH)
        assert result is not None
        assert result.hash == TEST_EXPECTED_HASH


def test_retrieve_model_by_hash(model_processor, db_session):
    # Test retrieving a model by hash
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
    assert result is not None
    assert result.hash == TEST_EXPECTED_HASH


def test_retrieve_model_by_hash_and_type(model_processor, db_session):
    # Test retrieving a model by hash and type
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
    assert result is not None
    assert result.hash == TEST_EXPECTED_HASH
    assert result.type == TEST_MODEL_TYPE


def test_retrieve_hash(model_processor, db_session):
    # Test retrieving the hash for an existing model
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    with patch.object(
        model_processor,
        "_validate_path",
        return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
    ):
        result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
        assert result == TEST_EXPECTED_HASH


def test_validate_file_extension_valid_extensions(model_processor):
    # Test all valid file extensions
    valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
    for ext in valid_extensions:
        model_processor._validate_file_extension(f"test{ext}")  # Should not raise


def test_process_file_existing_without_source_url(model_processor, db_session):
    # Test processing an existing file that needs its source URL updated
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)

    assert result is not None
    assert result.hash == TEST_EXPECTED_HASH
    assert result.source_url == TEST_URL

    db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
    assert db_model.source_url == TEST_URL
239  tests-unit/comfy_api_test/video_types_test.py  Normal file
@@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError

EPSILON = 0.0001


@pytest.fixture
def sample_images():
    """3-frame 2x2 RGB video tensor"""
    return torch.rand(3, 2, 2, 3)


@pytest.fixture
def sample_audio():
    """Stereo audio with 44.1kHz sample rate"""
    return AudioInput(
        {
            "waveform": torch.rand(1, 2, 1000),
            "sample_rate": 44100,
        }
    )


@pytest.fixture
def video_components(sample_images, sample_audio):
    """VideoComponents with images, audio, and metadata"""
    return VideoComponents(
        images=sample_images,
        audio=sample_audio,
        frame_rate=Fraction(30),
        metadata={"test": "metadata"},
    )


def create_test_video(width=4, height=4, frames=3, fps=30):
    """Helper to create a temporary video file"""
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    with av.open(tmp.name, mode="w") as container:
        stream = container.add_stream("h264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        for i in range(frames):
            frame = av.VideoFrame.from_ndarray(
                torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
                format="rgb24",
            )
            frame = frame.reformat(format="yuv420p")
            packet = stream.encode(frame)
            container.mux(packet)

        # Flush
        packet = stream.encode(None)
        container.mux(packet)

    return tmp.name


@pytest.fixture
def simple_video_file():
    """4x4 video with 3 frames at 30fps"""
    file_path = create_test_video()
    yield file_path
    os.unlink(file_path)


def test_video_from_components_get_duration(video_components):
    """Duration calculated correctly from frame count and frame rate"""
    video = VideoFromComponents(video_components)
    duration = video.get_duration()

    expected_duration = 3.0 / 30.0
    assert duration == pytest.approx(expected_duration)


def test_video_from_components_get_duration_different_frame_rates(sample_images):
    """Duration correct for different frame rates including fractional"""
    # Test with 60 fps
    components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
    video_60fps = VideoFromComponents(components_60fps)
    assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)

    # Test with fractional frame rate (23.976 fps)
    components_frac = VideoComponents(
        images=sample_images, frame_rate=Fraction(24000, 1001)
    )
    video_frac = VideoFromComponents(components_frac)
    expected_frac = 3.0 / (24000.0 / 1001.0)
    assert video_frac.get_duration() == pytest.approx(expected_frac)


def test_video_from_components_get_duration_empty_video():
    """Duration is zero for empty video"""
    empty_components = VideoComponents(
        images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
    )
    video = VideoFromComponents(empty_components)
    assert video.get_duration() == 0.0


def test_video_from_components_get_dimensions(video_components):
    """Dimensions returned correctly from image tensor shape"""
    video = VideoFromComponents(video_components)
    width, height = video.get_dimensions()
    assert width == 2
    assert height == 2


def test_video_from_file_get_duration(simple_video_file):
    """Duration extracted from file metadata"""
    video = VideoFromFile(simple_video_file)
    duration = video.get_duration()
    assert duration == pytest.approx(0.1, abs=0.01)


def test_video_from_file_get_dimensions(simple_video_file):
    """Dimensions read from stream without decoding frames"""
    video = VideoFromFile(simple_video_file)
    width, height = video.get_dimensions()
    assert width == 4
    assert height == 4


def test_video_from_file_bytesio_input():
    """VideoFromFile works with BytesIO input"""
    buffer = io.BytesIO()
    with av.open(buffer, mode="w", format="mp4") as container:
        stream = container.add_stream("h264", rate=30)
        stream.width = 2
        stream.height = 2
        stream.pix_fmt = "yuv420p"

        frame = av.VideoFrame.from_ndarray(
            torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
        )
        frame = frame.reformat(format="yuv420p")
        packet = stream.encode(frame)
        container.mux(packet)
        packet = stream.encode(None)
        container.mux(packet)

    buffer.seek(0)
    video = VideoFromFile(buffer)

    assert video.get_dimensions() == (2, 2)
    assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)


def test_video_from_file_invalid_file_error():
    """InvalidDataError raised for non-video files"""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
        tmp.write(b"not a video file")
        tmp.flush()
        tmp_name = tmp.name

    try:
        with pytest.raises(InvalidDataError):
            video = VideoFromFile(tmp_name)
            video.get_dimensions()
    finally:
        os.unlink(tmp_name)


def test_video_from_file_audio_only_error():
    """ValueError raised for audio-only files"""
    with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
        tmp_name = tmp.name

    try:
        with av.open(tmp_name, mode="w") as container:
            stream = container.add_stream("aac", rate=44100)
            stream.sample_rate = 44100
            stream.format = "fltp"

            audio_data = torch.zeros(1, 1024).numpy()
            audio_frame = av.AudioFrame.from_ndarray(
                audio_data, format="fltp", layout="mono"
            )
            audio_frame.sample_rate = 44100
            audio_frame.pts = 0
            packet = stream.encode(audio_frame)
            container.mux(packet)

            for packet in stream.encode(None):
                container.mux(packet)

        with pytest.raises(ValueError, match="No video stream found"):
            video = VideoFromFile(tmp_name)
            video.get_dimensions()
    finally:
        os.unlink(tmp_name)


def test_single_frame_video():
    """Single frame video has correct duration"""
    components = VideoComponents(
        images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
    )
    video = VideoFromComponents(components)
    assert video.get_duration() == 1.0


@pytest.mark.parametrize(
    "frame_rate,expected_fps",
    [
        (Fraction(24000, 1001), 24000 / 1001),
        (Fraction(30000, 1001), 30000 / 1001),
        (Fraction(25, 1), 25.0),
        (Fraction(50, 2), 25.0),
    ],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
    """Duration calculated correctly for various fractional frame rates"""
    components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
    video = VideoFromComponents(components)
    duration = video.get_duration()
    expected_duration = 100.0 / expected_fps
    assert duration == pytest.approx(expected_duration)


def test_duration_consistency(video_components):
    """get_duration() consistent with manual calculation from components"""
    video = VideoFromComponents(video_components)

    duration = video.get_duration()
    components = video.get_components()
    manual_duration = float(components.images.shape[0] / components.frame_rate)

    assert duration == pytest.approx(manual_duration)
0  tests-unit/comfy_extras_test/__init__.py  Normal file

240  tests-unit/comfy_extras_test/image_stitch_test.py  Normal file
@@ -0,0 +1,240 @@
import torch
from unittest.mock import patch, MagicMock

# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384

with patch.dict('sys.modules', {'nodes': mock_nodes}):
    from comfy_extras.nodes_images import ImageStitch


class TestImageStitch:

    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
        """Helper to create test images with specific dimensions"""
        return torch.rand(batch_size, height, width, channels)

    def test_no_image2_passthrough(self):
        """Test that when image2 is None, image1 is returned unchanged"""
        node = ImageStitch()
        image1 = self.create_test_image()

        result = node.stitch(image1, "right", True, 0, "white", image2=None)

        assert len(result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):
        """Test basic horizontal stitching to the right"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 32 + 24 width

    def test_basic_horizontal_stitch_left(self):
        """Test basic horizontal stitching to the left"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "left", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)  # 24 + 32 width

    def test_basic_vertical_stitch_down(self):
        """Test basic vertical stitching downward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 32 + 24 height

    def test_basic_vertical_stitch_up(self):
        """Test basic vertical stitching upward"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "up", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)  # 24 + 32 height

    def test_size_matching_horizontal(self):
        """Test size matching for horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)  # Different aspect ratio

        result = node.stitch(image1, "right", True, 0, "white", image2)

        # image2 should be resized to match image1's height (64) with preserved aspect ratio
        expected_width = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, 64, expected_width, 3)

    def test_size_matching_vertical(self):
        """Test size matching for vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "down", True, 0, "white", image2)

        # image2 should be resized to match image1's width (64) with preserved aspect ratio
        expected_height = 64 + 64  # original + resized (32*64/32 = 64)
        assert result[0].shape == (1, expected_height, 64, 3)

    def test_padding_for_mismatched_heights_horizontal(self):
        """Test padding when heights don't match in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=32)
        image2 = self.create_test_image(height=48, width=24)  # Shorter height

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Both images should be padded to height 64
        assert result[0].shape == (1, 64, 56, 3)  # 32 + 24 width, max(64, 48) height

    def test_padding_for_mismatched_widths_vertical(self):
        """Test padding when widths don't match in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=64)
        image2 = self.create_test_image(height=24, width=48)  # Narrower width

        result = node.stitch(image1, "down", False, 0, "white", image2)

        # Both images should be padded to width 64
        assert result[0].shape == (1, 56, 64, 3)  # 32 + 24 height, max(64, 48) width

    def test_spacing_horizontal(self):
        """Test spacing addition in horizontal concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)
        spacing_width = 16

        result = node.stitch(image1, "right", False, spacing_width, "white", image2)

        # Expected width: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 32, 72, 3)

    def test_spacing_vertical(self):
        """Test spacing addition in vertical concatenation"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)
        spacing_width = 16

        result = node.stitch(image1, "down", False, spacing_width, "white", image2)

        # Expected height: 32 + 16 (spacing) + 24 = 72
        assert result[0].shape == (1, 72, 32, 3)

    def test_spacing_color_values(self):
        """Test that spacing colors are applied correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Test white spacing
        result_white = node.stitch(image1, "right", False, 16, "white", image2)
        # Check that spacing region contains white values (close to 1.0)
        spacing_region = result_white[0][:, :, 32:48, :]  # Middle 16 pixels
        assert torch.all(spacing_region >= 0.9)  # Should be close to white

        # Test black spacing
        result_black = node.stitch(image1, "right", False, 16, "black", image2)
        spacing_region = result_black[0][:, :, 32:48, :]
        assert torch.all(spacing_region <= 0.1)  # Should be close to black

    def test_odd_spacing_width_made_even(self):
        """Test that odd spacing widths are made even"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Use odd spacing width
        result = node.stitch(image1, "right", False, 15, "white", image2)

        # Should be made even (16), so total width = 32 + 16 + 32 = 80
        assert result[0].shape == (1, 32, 80, 3)

    def test_batch_size_matching(self):
        """Test that different batch sizes are handled correctly"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=32, width=32)
        image2 = self.create_test_image(batch_size=1, height=32, width=32)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should match larger batch size
        assert result[0].shape == (2, 32, 64, 3)

    def test_channel_matching_rgb_to_rgba(self):
        """Test that channel differences are handled (RGB + alpha)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=3)  # RGB
        image2 = self.create_test_image(channels=4)  # RGBA

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_channel_matching_rgba_to_rgb(self):
        """Test that channel differences are handled (RGBA + RGB)"""
        node = ImageStitch()
        image1 = self.create_test_image(channels=4)  # RGBA
        image2 = self.create_test_image(channels=3)  # RGB

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Should have 4 channels (RGBA)
        assert result[0].shape[-1] == 4

    def test_all_color_options(self):
        """Test all available color options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        colors = ["white", "black", "red", "green", "blue"]

        for color in colors:
            result = node.stitch(image1, "right", False, 16, color, image2)
            assert result[0].shape == (1, 32, 80, 3)  # Basic shape check
    def test_all_directions(self):
        """Test all direction options"""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        directions = ["right", "left", "up", "down"]

        for direction in directions:
            result = node.stitch(image1, direction, False, 0, "white", image2)
            # Compute the expected shape first: a conditional expression placed after
            # `assert` binds to the whole assertion, so the "up"/"down" branch would
            # assert a non-empty tuple and always pass.
            expected_shape = (
                (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
            )
            assert result[0].shape == expected_shape
    def test_batch_size_channel_spacing_integration(self):
        """Test integration of batch matching, channel matching, size matching, and spacing"""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)

        result = node.stitch(image1, "right", True, 8, "red", image2)

        # Should handle: batch matching, size matching, channel matching, spacing
        assert result[0].shape[0] == 2  # Batch size matched
        assert result[0].shape[-1] == 4  # Channels matched to max
        assert result[0].shape[1] == 64  # Height from image1 (size matching)
        # Width should be: 48 + 8 (spacing) + resized_image2_width
        expected_image2_width = int(32 * (64 / 32))  # image2 resized to height 64 keeps its square aspect, so width 64
        expected_total_width = 48 + 8 + expected_image2_width
        assert result[0].shape[2] == expected_total_width
19  utils/install_util.py  Normal file
@@ -0,0 +1,19 @@
from pathlib import Path
import sys

# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"


def get_missing_requirements_message():
    """The warning message to display when a package is missing."""

    extra = ""
    if sys.flags.no_user_site:
        extra = "-s "
    return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}

If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()
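For context, a minimal sketch of how a helper like this might be consumed at startup, not part of this diff: the `comfyui_frontend_package` import target and the logging setup below are illustrative assumptions only.

# Hypothetical caller (illustration only): guard an import of a dependency
# listed in requirements.txt and surface the generated message if it is missing.
import logging

from utils.install_util import get_missing_requirements_message

try:
    import comfyui_frontend_package  # assumed dependency name, for illustration
except ImportError:
    logging.error(get_missing_requirements_message())
    raise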