Compare commits
79 Commits
pysssss-mo
...
worksplit-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0336b0ace8 | ||
|
|
8ae25235ec | ||
|
|
9726eac475 | ||
|
|
272e8d42c1 | ||
|
|
6211d2be5a | ||
|
|
8be711715c | ||
|
|
b5cccf1325 | ||
|
|
2a54a904f4 | ||
|
|
ed6f92c975 | ||
|
|
adc66c0698 | ||
|
|
ccd5c01e5a | ||
|
|
2fa9affcc1 | ||
|
|
407a5a656f | ||
|
|
9ce9ff8ef8 | ||
|
|
63567c0ce8 | ||
|
|
a786ce5ead | ||
|
|
4879b47648 | ||
|
|
5ccec33c22 | ||
|
|
219d3cd0d0 | ||
|
|
c4ba399475 | ||
|
|
cc928a786d | ||
|
|
6e144b98c4 | ||
|
|
6dca17bd2d | ||
|
|
5080105c23 | ||
|
|
093914a247 | ||
|
|
605893d3cf | ||
|
|
048f4f0b3a | ||
|
|
d2504fb701 | ||
|
|
b03763bca6 | ||
|
|
476aa79b64 | ||
|
|
441cfd1a7a | ||
|
|
99a5c1068a | ||
|
|
02747cde7d | ||
|
|
0b3233b4e2 | ||
|
|
eda866bf51 | ||
|
|
e3298b84de | ||
|
|
c7feef9060 | ||
|
|
51af7fa1b4 | ||
|
|
46969c380a | ||
|
|
5db4277449 | ||
|
|
02a4d0ad7d | ||
|
|
ef137ac0b6 | ||
|
|
328d4f16a9 | ||
|
|
bdbcb85b8d | ||
|
|
6c9e94bae7 | ||
|
|
bfce723311 | ||
|
|
31f5458938 | ||
|
|
2145a202eb | ||
|
|
25818dc848 | ||
|
|
198953cd08 | ||
|
|
ec16ee2f39 | ||
|
|
d5088072fb | ||
|
|
8d4b50158e | ||
|
|
e88c6c03ff | ||
|
|
d3cf2b7b24 | ||
|
|
7448f02b7c | ||
|
|
871258aa72 | ||
|
|
66838ebd39 | ||
|
|
7333281698 | ||
|
|
3cd4c5cb0a | ||
|
|
11c6d56037 | ||
|
|
216fea15ee | ||
|
|
58bf8815c8 | ||
|
|
1b38f5bf57 | ||
|
|
2724ac4a60 | ||
|
|
f48f90e471 | ||
|
|
6463c39ce0 | ||
|
|
0a7e2ae787 | ||
|
|
03a97b604a | ||
|
|
4446c86052 | ||
|
|
8270ff312f | ||
|
|
db2d7ad9ba | ||
|
|
6620d86318 | ||
|
|
111fd0cadf | ||
|
|
776aa734e1 | ||
|
|
5a2ad032cb | ||
|
|
d44295ef71 | ||
|
|
bf21be066f | ||
|
|
72bbf49349 |
84
alembic.ini
84
alembic.ini
@@ -1,84 +0,0 @@
|
|||||||
# A generic, single database configuration.
|
|
||||||
|
|
||||||
[alembic]
|
|
||||||
# path to migration scripts
|
|
||||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
|
||||||
script_location = alembic_db
|
|
||||||
|
|
||||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
|
||||||
# Uncomment the line below if you want the files to be prepended with date and time
|
|
||||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
|
||||||
# for all available tokens
|
|
||||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
|
||||||
|
|
||||||
# sys.path path, will be prepended to sys.path if present.
|
|
||||||
# defaults to the current working directory.
|
|
||||||
prepend_sys_path = .
|
|
||||||
|
|
||||||
# timezone to use when rendering the date within the migration file
|
|
||||||
# as well as the filename.
|
|
||||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
|
||||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
|
||||||
# string value is passed to ZoneInfo()
|
|
||||||
# leave blank for localtime
|
|
||||||
# timezone =
|
|
||||||
|
|
||||||
# max length of characters to apply to the "slug" field
|
|
||||||
# truncate_slug_length = 40
|
|
||||||
|
|
||||||
# set to 'true' to run the environment during
|
|
||||||
# the 'revision' command, regardless of autogenerate
|
|
||||||
# revision_environment = false
|
|
||||||
|
|
||||||
# set to 'true' to allow .pyc and .pyo files without
|
|
||||||
# a source .py file to be detected as revisions in the
|
|
||||||
# versions/ directory
|
|
||||||
# sourceless = false
|
|
||||||
|
|
||||||
# version location specification; This defaults
|
|
||||||
# to alembic_db/versions. When using multiple version
|
|
||||||
# directories, initial revisions must be specified with --version-path.
|
|
||||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
|
||||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
|
||||||
|
|
||||||
# version path separator; As mentioned above, this is the character used to split
|
|
||||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
|
||||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
|
||||||
# Valid values for version_path_separator are:
|
|
||||||
#
|
|
||||||
# version_path_separator = :
|
|
||||||
# version_path_separator = ;
|
|
||||||
# version_path_separator = space
|
|
||||||
# version_path_separator = newline
|
|
||||||
#
|
|
||||||
# Use os.pathsep. Default configuration used for new projects.
|
|
||||||
version_path_separator = os
|
|
||||||
|
|
||||||
# set to 'true' to search source files recursively
|
|
||||||
# in each "version_locations" directory
|
|
||||||
# new in Alembic version 1.10
|
|
||||||
# recursive_version_locations = false
|
|
||||||
|
|
||||||
# the output encoding used when revision files
|
|
||||||
# are written from script.py.mako
|
|
||||||
# output_encoding = utf-8
|
|
||||||
|
|
||||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
|
||||||
# post_write_hooks defines scripts or Python functions that are run
|
|
||||||
# on newly generated revision scripts. See the documentation for further
|
|
||||||
# detail and examples
|
|
||||||
|
|
||||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
|
||||||
# hooks = black
|
|
||||||
# black.type = console_scripts
|
|
||||||
# black.entrypoint = black
|
|
||||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
|
||||||
|
|
||||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
|
||||||
# hooks = ruff
|
|
||||||
# ruff.type = exec
|
|
||||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
|
||||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
## Generate new revision
|
|
||||||
|
|
||||||
1. Update models in `/app/database/models.py`
|
|
||||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
from sqlalchemy import engine_from_config
|
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
|
||||||
# access to the values within the .ini file in use.
|
|
||||||
config = context.config
|
|
||||||
|
|
||||||
|
|
||||||
from app.database.models import Base
|
|
||||||
target_metadata = Base.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
|
||||||
# can be acquired:
|
|
||||||
# my_important_option = config.get_main_option("my_important_option")
|
|
||||||
# ... etc.
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
|
||||||
"""Run migrations in 'offline' mode.
|
|
||||||
|
|
||||||
This configures the context with just a URL
|
|
||||||
and not an Engine, though an Engine is acceptable
|
|
||||||
here as well. By skipping the Engine creation
|
|
||||||
we don't even need a DBAPI to be available.
|
|
||||||
|
|
||||||
Calls to context.execute() here emit the given string to the
|
|
||||||
script output.
|
|
||||||
|
|
||||||
"""
|
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
|
||||||
context.configure(
|
|
||||||
url=url,
|
|
||||||
target_metadata=target_metadata,
|
|
||||||
literal_binds=True,
|
|
||||||
dialect_opts={"paramstyle": "named"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_online() -> None:
|
|
||||||
"""Run migrations in 'online' mode.
|
|
||||||
|
|
||||||
In this scenario we need to create an Engine
|
|
||||||
and associate a connection with the context.
|
|
||||||
|
|
||||||
"""
|
|
||||||
connectable = engine_from_config(
|
|
||||||
config.get_section(config.config_ini_section, {}),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
|
||||||
context.configure(
|
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
if context.is_offline_mode():
|
|
||||||
run_migrations_offline()
|
|
||||||
else:
|
|
||||||
run_migrations_online()
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
"""${message}
|
|
||||||
|
|
||||||
Revision ID: ${up_revision}
|
|
||||||
Revises: ${down_revision | comma,n}
|
|
||||||
Create Date: ${create_date}
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
${imports if imports else ""}
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = ${repr(up_revision)}
|
|
||||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
|
||||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
${upgrades if upgrades else "pass"}
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
${downgrades if downgrades else "pass"}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
"""init
|
|
||||||
|
|
||||||
Revision ID: e9c714da8d57
|
|
||||||
Revises:
|
|
||||||
Create Date: 2025-05-30 20:14:33.772039
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'e9c714da8d57'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
op.create_table('model',
|
|
||||||
sa.Column('type', sa.Text(), nullable=False),
|
|
||||||
sa.Column('path', sa.Text(), nullable=False),
|
|
||||||
sa.Column('file_name', sa.Text(), nullable=True),
|
|
||||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
|
||||||
sa.Column('hash', sa.Text(), nullable=True),
|
|
||||||
sa.Column('hash_algorithm', sa.Text(), nullable=True),
|
|
||||||
sa.Column('source_url', sa.Text(), nullable=True),
|
|
||||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('type', 'path')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('model')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from app.logger import log_startup_warning
|
|
||||||
from utils.install_util import get_missing_requirements_message
|
|
||||||
from comfy.cli_args import args
|
|
||||||
|
|
||||||
_DB_AVAILABLE = False
|
|
||||||
Session = None
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from alembic import command
|
|
||||||
from alembic.config import Config
|
|
||||||
from alembic.runtime.migration import MigrationContext
|
|
||||||
from alembic.script import ScriptDirectory
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
_DB_AVAILABLE = True
|
|
||||||
except ImportError as e:
|
|
||||||
log_startup_warning(
|
|
||||||
f"""
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
Error importing dependencies: {e}
|
|
||||||
|
|
||||||
{get_missing_requirements_message()}
|
|
||||||
|
|
||||||
This error is happening because ComfyUI now uses a local sqlite database.
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dependencies_available():
|
|
||||||
"""
|
|
||||||
Temporary function to check if the dependencies are available
|
|
||||||
"""
|
|
||||||
return _DB_AVAILABLE
|
|
||||||
|
|
||||||
|
|
||||||
def can_create_session():
|
|
||||||
"""
|
|
||||||
Temporary function to check if the database is available to create a session
|
|
||||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
|
||||||
"""
|
|
||||||
return dependencies_available() and Session is not None
|
|
||||||
|
|
||||||
|
|
||||||
def get_alembic_config():
|
|
||||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
|
||||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
|
||||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
|
||||||
|
|
||||||
config = Config(config_path)
|
|
||||||
config.set_main_option("script_location", scripts_path)
|
|
||||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_path():
|
|
||||||
url = args.database_url
|
|
||||||
if url.startswith("sqlite:///"):
|
|
||||||
return url.split("///")[1]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
|
||||||
db_url = args.database_url
|
|
||||||
logging.debug(f"Database URL: {db_url}")
|
|
||||||
db_path = get_db_path()
|
|
||||||
db_exists = os.path.exists(db_path)
|
|
||||||
|
|
||||||
config = get_alembic_config()
|
|
||||||
|
|
||||||
# Check if we need to upgrade
|
|
||||||
engine = create_engine(db_url)
|
|
||||||
conn = engine.connect()
|
|
||||||
|
|
||||||
context = MigrationContext.configure(conn)
|
|
||||||
current_rev = context.get_current_revision()
|
|
||||||
|
|
||||||
script = ScriptDirectory.from_config(config)
|
|
||||||
target_rev = script.get_current_head()
|
|
||||||
|
|
||||||
if current_rev != target_rev:
|
|
||||||
# Backup the database pre upgrade
|
|
||||||
backup_path = db_path + ".bkp"
|
|
||||||
if db_exists:
|
|
||||||
shutil.copy(db_path, backup_path)
|
|
||||||
else:
|
|
||||||
backup_path = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
command.upgrade(config, target_rev)
|
|
||||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
|
||||||
except Exception as e:
|
|
||||||
if backup_path:
|
|
||||||
# Restore the database from backup if upgrade fails
|
|
||||||
shutil.copy(backup_path, db_path)
|
|
||||||
os.remove(backup_path)
|
|
||||||
logging.error(f"Error upgrading database: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
global Session
|
|
||||||
Session = sessionmaker(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
def create_session():
|
|
||||||
return Session()
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
from sqlalchemy import (
|
|
||||||
Column,
|
|
||||||
Integer,
|
|
||||||
Text,
|
|
||||||
DateTime,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
from sqlalchemy.sql import func
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(obj):
|
|
||||||
fields = obj.__table__.columns.keys()
|
|
||||||
return {
|
|
||||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
|
||||||
for field in fields
|
|
||||||
if (val := getattr(obj, field))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Model(Base):
|
|
||||||
"""
|
|
||||||
sqlalchemy model representing a model file in the system.
|
|
||||||
|
|
||||||
This class defines the database schema for storing information about model files,
|
|
||||||
including their type, path, hash, and when they were added to the system.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
|
||||||
path (Text): The file path of the model relative to the type folder (primary key)
|
|
||||||
file_name (Text): The name of the model file
|
|
||||||
file_size (Integer): The size of the model file in bytes
|
|
||||||
hash (Text): A hash of the model file
|
|
||||||
hash_algorithm (Text): The algorithm used to generate the hash
|
|
||||||
source_url (Text): The URL of the model file
|
|
||||||
date_added (DateTime): Timestamp of when the model was added to the system
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "model"
|
|
||||||
|
|
||||||
type = Column(Text, primary_key=True)
|
|
||||||
path = Column(Text, primary_key=True)
|
|
||||||
file_name = Column(Text)
|
|
||||||
file_size = Column(Integer)
|
|
||||||
hash = Column(Text)
|
|
||||||
hash_algorithm = Column(Text)
|
|
||||||
source_url = Column(Text)
|
|
||||||
date_added = Column(DateTime, server_default=func.now())
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
"""
|
|
||||||
Convert the model instance to a dictionary representation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing the attributes of the model
|
|
||||||
"""
|
|
||||||
dict = to_dict(self)
|
|
||||||
return dict
|
|
||||||
@@ -16,15 +16,26 @@ from importlib.metadata import version
|
|||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
|
||||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
import app.logger
|
import app.logger
|
||||||
|
|
||||||
|
# The path to the requirements.txt file
|
||||||
|
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||||
|
|
||||||
|
|
||||||
def frontend_install_warning_message():
|
def frontend_install_warning_message():
|
||||||
|
"""The warning message to display when the frontend version is not up to date."""
|
||||||
|
|
||||||
|
extra = ""
|
||||||
|
if sys.flags.no_user_site:
|
||||||
|
extra = "-s "
|
||||||
return f"""
|
return f"""
|
||||||
{get_missing_requirements_message()}
|
Please install the updated requirements.txt file by running:
|
||||||
|
{sys.executable} {extra}-m pip install -r {req_path}
|
||||||
|
|
||||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||||
|
|
||||||
|
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +48,7 @@ def check_frontend_version():
|
|||||||
try:
|
try:
|
||||||
frontend_version_str = version("comfyui-frontend-package")
|
frontend_version_str = version("comfyui-frontend-package")
|
||||||
frontend_version = parse_version(frontend_version_str)
|
frontend_version = parse_version(frontend_version_str)
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
with open(req_path, "r", encoding="utf-8") as f:
|
||||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||||
if frontend_version < required_frontend:
|
if frontend_version < required_frontend:
|
||||||
app.logger.log_startup_warning(
|
app.logger.log_startup_warning(
|
||||||
@@ -151,30 +162,10 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class FrontendManager:
|
class FrontendManager:
|
||||||
"""
|
|
||||||
A class to manage ComfyUI frontend versions and installations.
|
|
||||||
|
|
||||||
This class handles the initialization and management of different frontend versions,
|
|
||||||
including the default frontend from the pip package and custom frontend versions
|
|
||||||
from GitHub repositories.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
|
|
||||||
"""
|
|
||||||
|
|
||||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_frontend_path(cls) -> str:
|
def default_frontend_path(cls) -> str:
|
||||||
"""
|
|
||||||
Get the path to the default frontend installation from the pip package.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path to the default frontend static files.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SystemExit: If the comfyui-frontend-package is not installed.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
import comfyui_frontend_package
|
import comfyui_frontend_package
|
||||||
|
|
||||||
@@ -195,15 +186,6 @@ comfyui-frontend-package is not installed.
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def templates_path(cls) -> str:
|
def templates_path(cls) -> str:
|
||||||
"""
|
|
||||||
Get the path to the workflow templates.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path to the workflow templates directory.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SystemExit: If the comfyui-workflow-templates package is not installed.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
import comfyui_workflow_templates
|
import comfyui_workflow_templates
|
||||||
|
|
||||||
@@ -239,16 +221,11 @@ comfyui-workflow-templates is not installed.
|
|||||||
@classmethod
|
@classmethod
|
||||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Parse a version string into its components.
|
|
||||||
|
|
||||||
The version string should be in the format: 'owner/repo@version'
|
|
||||||
where version can be either a semantic version (v1.2.3) or 'latest'.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value (str): The version string to parse.
|
value (str): The version string to parse.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, str, str]: A tuple containing (owner, repo, version).
|
tuple[str, str]: A tuple containing provider name and version.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
argparse.ArgumentTypeError: If the version string is invalid.
|
argparse.ArgumentTypeError: If the version string is invalid.
|
||||||
@@ -265,22 +242,18 @@ comfyui-workflow-templates is not installed.
|
|||||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Initialize a frontend version without error handling.
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
This method attempts to initialize a specific frontend version, either from
|
|
||||||
the default pip package or from a custom GitHub repository. It will download
|
|
||||||
and extract the frontend files if necessary.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
version_string (str): The version string specifying which frontend to use.
|
version_string (str): The version string.
|
||||||
provider (FrontEndProvider, optional): The provider to use for custom frontends.
|
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the initialized frontend.
|
str: The path to the initialized frontend.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If there is an error during initialization (e.g., network timeout,
|
Exception: If there is an error during the initialization process.
|
||||||
invalid URL, or missing assets).
|
main error source might be request timeout or invalid URL.
|
||||||
"""
|
"""
|
||||||
if version_string == DEFAULT_VERSION_STRING:
|
if version_string == DEFAULT_VERSION_STRING:
|
||||||
check_frontend_version()
|
check_frontend_version()
|
||||||
@@ -332,17 +305,13 @@ comfyui-workflow-templates is not installed.
|
|||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend(cls, version_string: str) -> str:
|
def init_frontend(cls, version_string: str) -> str:
|
||||||
"""
|
"""
|
||||||
Initialize a frontend version with error handling.
|
Initializes the frontend with the specified version string.
|
||||||
|
|
||||||
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
|
|
||||||
with error handling, falling back to the default frontend if initialization fails.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
version_string (str): The version string specifying which frontend to use.
|
version_string (str): The version string to initialize the frontend with.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the initialized frontend. If initialization fails,
|
str: The path of the initialized frontend.
|
||||||
returns the path to the default frontend.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return cls.init_frontend_unsafe(version_string)
|
return cls.init_frontend_unsafe(version_string)
|
||||||
|
|||||||
@@ -1,331 +0,0 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from tqdm import tqdm
|
|
||||||
from folder_paths import get_relative_path, get_full_path
|
|
||||||
from app.database.db import create_session, dependencies_available, can_create_session
|
|
||||||
import blake3
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
|
|
||||||
if dependencies_available():
|
|
||||||
from app.database.models import Model
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProcessor:
|
|
||||||
def _validate_path(self, model_path):
|
|
||||||
try:
|
|
||||||
if not self._file_exists(model_path):
|
|
||||||
logging.error(f"Model file not found: {model_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = get_relative_path(model_path)
|
|
||||||
if not result:
|
|
||||||
logging.error(
|
|
||||||
f"Model file not in a recognized model directory: {model_path}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _file_exists(self, path):
|
|
||||||
"""Check if a file exists."""
|
|
||||||
return os.path.exists(path)
|
|
||||||
|
|
||||||
def _get_file_size(self, path):
|
|
||||||
"""Get file size."""
|
|
||||||
return os.path.getsize(path)
|
|
||||||
|
|
||||||
def _get_hasher(self):
|
|
||||||
return blake3.blake3()
|
|
||||||
|
|
||||||
def _hash_file(self, model_path):
|
|
||||||
try:
|
|
||||||
hasher = self._get_hasher()
|
|
||||||
with open(model_path, "rb", buffering=0) as f:
|
|
||||||
b = bytearray(128 * 1024)
|
|
||||||
mv = memoryview(b)
|
|
||||||
while n := f.readinto(mv):
|
|
||||||
hasher.update(mv[:n])
|
|
||||||
return hasher.hexdigest()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
|
||||||
return (
|
|
||||||
session.query(Model)
|
|
||||||
.filter(Model.type == model_type)
|
|
||||||
.filter(Model.path == model_relative_path)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ensure_source_url(self, session, model, source_url):
|
|
||||||
if model.source_url is None:
|
|
||||||
model.source_url = source_url
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def _update_database(
|
|
||||||
self,
|
|
||||||
session,
|
|
||||||
model_type,
|
|
||||||
model_path,
|
|
||||||
model_relative_path,
|
|
||||||
model_hash,
|
|
||||||
model,
|
|
||||||
source_url,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
if not model:
|
|
||||||
model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
model = Model(
|
|
||||||
path=model_relative_path,
|
|
||||||
type=model_type,
|
|
||||||
file_name=os.path.basename(model_path),
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
|
|
||||||
model.file_size = self._get_file_size(model_path)
|
|
||||||
model.hash = model_hash
|
|
||||||
if model_hash:
|
|
||||||
model.hash_algorithm = "blake3"
|
|
||||||
model.source_url = source_url
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return model
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(
|
|
||||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
|
||||||
"""
|
|
||||||
Process a model file and update the database with metadata.
|
|
||||||
If the file already exists and matches the database, it will not be processed again.
|
|
||||||
Returns the model object or if an error occurs, returns None.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
result = self._validate_path(model_path)
|
|
||||||
if not result:
|
|
||||||
return
|
|
||||||
model_type, model_relative_path = result
|
|
||||||
|
|
||||||
with create_session() as session:
|
|
||||||
session.expire_on_commit = False
|
|
||||||
|
|
||||||
existing_model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
existing_model
|
|
||||||
and existing_model.hash
|
|
||||||
and existing_model.file_size == self._get_file_size(model_path)
|
|
||||||
):
|
|
||||||
# File exists with hash and same size, no need to process
|
|
||||||
self._ensure_source_url(session, existing_model, source_url)
|
|
||||||
return existing_model
|
|
||||||
|
|
||||||
if model_hash:
|
|
||||||
model_hash = model_hash.lower()
|
|
||||||
logging.info(f"Using provided hash: {model_hash}")
|
|
||||||
else:
|
|
||||||
start_time = time.time()
|
|
||||||
logging.info(f"Hashing model {model_relative_path}")
|
|
||||||
model_hash = self._hash_file(model_path)
|
|
||||||
if not model_hash:
|
|
||||||
return
|
|
||||||
logging.info(
|
|
||||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._update_database(
|
|
||||||
session,
|
|
||||||
model_type,
|
|
||||||
model_path,
|
|
||||||
model_relative_path,
|
|
||||||
model_hash,
|
|
||||||
existing_model,
|
|
||||||
source_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
|
||||||
"""
|
|
||||||
Retrieve a model file from the database by hash and optionally by model type.
|
|
||||||
Returns the model object or None if the model doesnt exist or an error occurs.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
dispose_session = False
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = create_session()
|
|
||||||
dispose_session = True
|
|
||||||
|
|
||||||
model = session.query(Model).filter(Model.hash == model_hash)
|
|
||||||
if model_type is not None:
|
|
||||||
model = model.filter(Model.type == model_type)
|
|
||||||
return model.first()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
if dispose_session:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
def retrieve_hash(self, model_path, model_type=None):
|
|
||||||
"""
|
|
||||||
Retrieve the hash of a model file from the database.
|
|
||||||
Returns the hash or None if the model doesnt exist or an error occurs.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
if model_type is not None:
|
|
||||||
result = self._validate_path(model_path)
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
model_type, model_relative_path = result
|
|
||||||
|
|
||||||
with create_session() as session:
|
|
||||||
model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
if model and model.hash:
|
|
||||||
return model.hash
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _validate_file_extension(self, file_name):
|
|
||||||
"""Validate that the file extension is supported."""
|
|
||||||
extension = os.path.splitext(file_name)[1]
|
|
||||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
|
||||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
|
||||||
|
|
||||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
|
||||||
"""Check if file exists and has correct hash."""
|
|
||||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
|
||||||
if self._file_exists(destination_path):
|
|
||||||
model = self.process_file(destination_path)
|
|
||||||
if model and (expected_hash is None or model.hash == expected_hash):
|
|
||||||
logging.debug(
|
|
||||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
|
||||||
)
|
|
||||||
return destination_path
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _check_existing_file_by_hash(self, hash, type, url):
|
|
||||||
"""Check if a file with the given hash exists in the database and on disk."""
|
|
||||||
hash = hash.lower()
|
|
||||||
with create_session() as session:
|
|
||||||
model = self.retrieve_model_by_hash(hash, type, session)
|
|
||||||
if model:
|
|
||||||
existing_path = get_full_path(type, model.path)
|
|
||||||
if existing_path:
|
|
||||||
logging.debug(
|
|
||||||
f"File {model.path} already exists in the database at {existing_path}"
|
|
||||||
)
|
|
||||||
self._ensure_source_url(session, model, url)
|
|
||||||
return existing_path
|
|
||||||
else:
|
|
||||||
logging.debug(
|
|
||||||
f"File {model.path} exists in the database but not on disk"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _download_file(self, url, destination_path, hasher):
|
|
||||||
"""Download a file and update the hasher with its contents."""
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
logging.info(f"Downloading {url} to {destination_path}")
|
|
||||||
|
|
||||||
with open(destination_path, "wb") as f:
|
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
|
||||||
if total_size > 0:
|
|
||||||
pbar = comfy.utils.ProgressBar(total_size)
|
|
||||||
else:
|
|
||||||
pbar = None
|
|
||||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
|
||||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
|
||||||
hasher.update(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
if pbar:
|
|
||||||
pbar.update(len(chunk))
|
|
||||||
|
|
||||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
|
||||||
"""Verify that the downloaded file has the expected hash."""
|
|
||||||
if expected_hash is not None and calculated_hash != expected_hash:
|
|
||||||
self._remove_file(destination_path)
|
|
||||||
raise ValueError(
|
|
||||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _remove_file(self, file_path):
|
|
||||||
"""Remove a file from disk."""
|
|
||||||
os.remove(file_path)
|
|
||||||
|
|
||||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
|
||||||
"""
|
|
||||||
Ensure a model file is downloaded and has the correct hash.
|
|
||||||
Returns the path to the downloaded file.
|
|
||||||
"""
|
|
||||||
logging.debug(
|
|
||||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate file extension
|
|
||||||
self._validate_file_extension(desired_file_name)
|
|
||||||
|
|
||||||
# Check if file exists with correct hash
|
|
||||||
if hash:
|
|
||||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
|
||||||
if existing_path:
|
|
||||||
return existing_path
|
|
||||||
|
|
||||||
# Check if file exists locally
|
|
||||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
|
||||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
|
||||||
if existing_path:
|
|
||||||
return existing_path
|
|
||||||
|
|
||||||
# Download the file
|
|
||||||
hasher = self._get_hasher()
|
|
||||||
self._download_file(url, destination_path, hasher)
|
|
||||||
|
|
||||||
# Verify hash
|
|
||||||
calculated_hash = hasher.hexdigest()
|
|
||||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
|
||||||
|
|
||||||
# Update database
|
|
||||||
self.process_file(destination_path, url, calculated_hash)
|
|
||||||
|
|
||||||
# TODO: Notify frontend to reload models
|
|
||||||
|
|
||||||
return destination_path
|
|
||||||
|
|
||||||
|
|
||||||
model_processor = ModelProcessor()
|
|
||||||
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
|||||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
|
||||||
cm_group = parser.add_mutually_exclusive_group()
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||||
@@ -203,12 +203,6 @@ parser.add_argument(
|
|||||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||||
)
|
)
|
||||||
|
|
||||||
database_default_path = os.path.abspath(
|
|
||||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
|
||||||
)
|
|
||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
|
||||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -15,13 +15,14 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import copy
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
@@ -36,7 +37,7 @@ import comfy.cldm.mmdit
|
|||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Union
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.hooks import HookGroup
|
from comfy.hooks import HookGroup
|
||||||
|
|
||||||
@@ -63,6 +64,18 @@ class StrengthType(Enum):
|
|||||||
CONSTANT = 1
|
CONSTANT = 1
|
||||||
LINEAR_UP = 2
|
LINEAR_UP = 2
|
||||||
|
|
||||||
|
class ControlIsolation:
|
||||||
|
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||||
|
def __init__(self, control: ControlBase):
|
||||||
|
self.control = control
|
||||||
|
self.orig_previous_controlnet = control.previous_controlnet
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.control.previous_controlnet = None
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
@@ -76,7 +89,7 @@ class ControlBase:
|
|||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet: Union[ControlBase, None] = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
@@ -84,6 +97,7 @@ class ControlBase:
|
|||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.extra_hooks: HookGroup = None
|
self.extra_hooks: HookGroup = None
|
||||||
self.preprocess_image = lambda a: a
|
self.preprocess_image = lambda a: a
|
||||||
|
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@@ -110,17 +124,38 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
|
for device_cnet in self.multigpu_clones.values():
|
||||||
|
with ControlIsolation(device_cnet):
|
||||||
|
device_cnet.cleanup()
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
|
for device_cnet in self.multigpu_clones.values():
|
||||||
|
out += device_cnet.get_models_only_self()
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def get_models_only_self(self):
|
||||||
|
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||||
|
with ControlIsolation(self):
|
||||||
|
return self.get_models()
|
||||||
|
|
||||||
|
def get_instance_for_device(self, device):
|
||||||
|
'Returns instance of this Control object intended for selected device.'
|
||||||
|
return self.multigpu_clones.get(device, self)
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
'''
|
||||||
|
Create deep clone of Control object where model(s) is set to other devices.
|
||||||
|
|
||||||
|
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||||
|
'''
|
||||||
|
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
def get_extra_hooks(self):
|
||||||
out = []
|
out = []
|
||||||
if self.extra_hooks is not None:
|
if self.extra_hooks is not None:
|
||||||
@@ -129,7 +164,7 @@ class ControlBase:
|
|||||||
out += self.previous_controlnet.get_extra_hooks()
|
out += self.previous_controlnet.get_extra_hooks()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def copy_to(self, c):
|
def copy_to(self, c: ControlBase):
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
@@ -280,6 +315,14 @@ class ControlNet(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.control_model = copy.deepcopy(c.control_model)
|
||||||
|
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = super().get_models()
|
out = super().get_models()
|
||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
@@ -805,6 +848,14 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||||
|
c = self.copy()
|
||||||
|
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||||
|
c.device = load_device
|
||||||
|
if autoregister:
|
||||||
|
self.multigpu_clones[load_device] = c
|
||||||
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|||||||
@@ -102,13 +102,6 @@ def model_sampling(model_config, model_type):
|
|||||||
return ModelSampling(model_config)
|
return ModelSampling(model_config)
|
||||||
|
|
||||||
|
|
||||||
def convert_tensor(extra, dtype):
|
|
||||||
if hasattr(extra, "dtype"):
|
|
||||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
|
||||||
extra = extra.to(dtype)
|
|
||||||
return extra
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -172,13 +165,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
extra = kwargs[o]
|
extra = kwargs[o]
|
||||||
|
|
||||||
if hasattr(extra, "dtype"):
|
if hasattr(extra, "dtype"):
|
||||||
extra = convert_tensor(extra, dtype)
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
elif isinstance(extra, list):
|
extra = extra.to(dtype)
|
||||||
|
if isinstance(extra, list):
|
||||||
ex = []
|
ex = []
|
||||||
for ext in extra:
|
for ext in extra:
|
||||||
ex.append(convert_tensor(ext, dtype))
|
ex.append(ext.to(dtype))
|
||||||
extra = ex
|
extra = ex
|
||||||
extra_conds[o] = extra
|
extra_conds[o] = extra
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
@@ -26,6 +27,10 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@@ -171,6 +176,25 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
|
def get_all_torch_devices(exclude_current=False):
|
||||||
|
global cpu_state
|
||||||
|
devices = []
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
if is_nvidia():
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
devices.append(torch.device(i))
|
||||||
|
elif is_intel_xpu():
|
||||||
|
for i in range(torch.xpu.device_count()):
|
||||||
|
devices.append(torch.device(i))
|
||||||
|
elif is_ascend_npu():
|
||||||
|
for i in range(torch.npu.device_count()):
|
||||||
|
devices.append(torch.device(i))
|
||||||
|
else:
|
||||||
|
devices.append(get_torch_device())
|
||||||
|
if exclude_current:
|
||||||
|
devices.remove(get_torch_device())
|
||||||
|
return devices
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
@@ -387,9 +411,13 @@ try:
|
|||||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
try:
|
||||||
|
for device in get_all_torch_devices(exclude_current=True):
|
||||||
|
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
current_loaded_models: list[LoadedModel] = []
|
||||||
current_loaded_models = []
|
|
||||||
|
|
||||||
def module_size(module):
|
def module_size(module):
|
||||||
module_mem = 0
|
module_mem = 0
|
||||||
@@ -400,7 +428,7 @@ def module_size(module):
|
|||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model: ModelPatcher):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
@@ -408,7 +436,7 @@ class LoadedModel:
|
|||||||
self.model_finalizer = None
|
self.model_finalizer = None
|
||||||
self._patcher_finalizer = None
|
self._patcher_finalizer = None
|
||||||
|
|
||||||
def _set_model(self, model):
|
def _set_model(self, model: ModelPatcher):
|
||||||
self._model = weakref.ref(model)
|
self._model = weakref.ref(model)
|
||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
@@ -1300,8 +1328,34 @@ def soft_empty_cache(force=False):
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device)
|
||||||
|
|
||||||
|
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||||
|
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||||
|
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||||
|
additional_models = []
|
||||||
|
if unload_additional_models:
|
||||||
|
additional_models = model.get_nested_additional_models()
|
||||||
|
keep_loaded = []
|
||||||
|
for loaded_model in initial_keep_loaded:
|
||||||
|
if loaded_model.model is not None:
|
||||||
|
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
continue
|
||||||
|
# check additional models if they are a match
|
||||||
|
skip = False
|
||||||
|
for add_model in additional_models:
|
||||||
|
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
skip = True
|
||||||
|
break
|
||||||
|
if skip:
|
||||||
|
continue
|
||||||
|
keep_loaded.append(loaded_model)
|
||||||
|
if not all_devices:
|
||||||
|
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||||
|
else:
|
||||||
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device, keep_loaded)
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|||||||
@@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
def create_model_options_clone(orig_model_options: dict):
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
def create_hook_patches_clone(orig_hook_patches):
|
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||||
new_hook_patches = {}
|
new_hook_patches = {}
|
||||||
for hook_ref in orig_hook_patches:
|
for hook_ref in orig_hook_patches:
|
||||||
new_hook_patches[hook_ref] = {}
|
new_hook_patches[hook_ref] = {}
|
||||||
for k in orig_hook_patches[hook_ref]:
|
for k in orig_hook_patches[hook_ref]:
|
||||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||||
|
if copy_tuples:
|
||||||
|
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||||
|
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||||
return new_hook_patches
|
return new_hook_patches
|
||||||
|
|
||||||
def wipe_lowvram_weight(m):
|
def wipe_lowvram_weight(m):
|
||||||
@@ -240,6 +243,9 @@ class ModelPatcher:
|
|||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
|
self.is_multigpu_base_clone = False
|
||||||
|
self.clone_base_uuid = uuid.uuid4()
|
||||||
|
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@@ -318,16 +324,90 @@ class ModelPatcher:
|
|||||||
n.is_clip = self.is_clip
|
n.is_clip = self.is_clip
|
||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
|
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||||
|
n.clone_base_uuid = self.clone_base_uuid
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||||
|
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||||
|
comfy.model_management.unload_model_and_clones(self)
|
||||||
|
n = self.clone()
|
||||||
|
# set load device, if present
|
||||||
|
if new_load_device is not None:
|
||||||
|
n.load_device = new_load_device
|
||||||
|
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||||
|
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||||
|
n.backup = copy.deepcopy(n.backup)
|
||||||
|
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||||
|
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||||
|
n.model = copy.deepcopy(n.model)
|
||||||
|
# multigpu clone should not have multigpu additional_models entry
|
||||||
|
n.remove_additional_models("multigpu")
|
||||||
|
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||||
|
if models_cache is None:
|
||||||
|
models_cache = {}
|
||||||
|
for key, model_list in n.additional_models.items():
|
||||||
|
for i in range(len(model_list)):
|
||||||
|
add_model = n.additional_models[key][i]
|
||||||
|
if add_model.clone_base_uuid not in models_cache:
|
||||||
|
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||||
|
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||||
|
callback(self, n)
|
||||||
|
return n
|
||||||
|
|
||||||
|
def match_multigpu_clones(self):
|
||||||
|
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) > 0:
|
||||||
|
new_multigpu_models = []
|
||||||
|
for mm in multigpu_models:
|
||||||
|
# clone main model, but bring over relevant props from existing multigpu clone
|
||||||
|
n = self.clone()
|
||||||
|
n.load_device = mm.load_device
|
||||||
|
n.backup = mm.backup
|
||||||
|
n.object_patches_backup = mm.object_patches_backup
|
||||||
|
n.hook_backup = mm.hook_backup
|
||||||
|
n.model = mm.model
|
||||||
|
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||||
|
n.remove_additional_models("multigpu")
|
||||||
|
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||||
|
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||||
|
# figure out which additional models are not present in multigpu clone
|
||||||
|
models_cache = {}
|
||||||
|
for mm_add_model in mm.get_additional_models():
|
||||||
|
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||||
|
remove_models_uuids = set(list(models_cache.keys()))
|
||||||
|
for key, model_list in orig_additional_models.items():
|
||||||
|
for orig_add_model in model_list:
|
||||||
|
if orig_add_model.clone_base_uuid not in models_cache:
|
||||||
|
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||||
|
existing_list = n.get_additional_models_with_key(key)
|
||||||
|
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||||
|
n.set_additional_models(key, existing_list)
|
||||||
|
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||||
|
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||||
|
# remove duplicate additional models
|
||||||
|
for key, model_list in n.additional_models.items():
|
||||||
|
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||||
|
n.set_additional_models(key, new_model_list)
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||||
|
callback(self, n)
|
||||||
|
new_multigpu_models.append(n)
|
||||||
|
self.set_additional_models("multigpu", new_multigpu_models)
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
if hasattr(other, 'model') and self.model is other.model:
|
if hasattr(other, 'model') and self.model is other.model:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||||
|
if allow_multigpu:
|
||||||
|
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
if not self.is_clone(clone):
|
if not self.is_clone(clone):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -929,7 +1009,7 @@ class ModelPatcher:
|
|||||||
return self.additional_models.get(key, [])
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
def get_additional_models(self):
|
def get_additional_models(self):
|
||||||
all_models = []
|
all_models: list[ModelPatcher] = []
|
||||||
for models in self.additional_models.values():
|
for models in self.additional_models.values():
|
||||||
all_models.extend(models)
|
all_models.extend(models)
|
||||||
return all_models
|
return all_models
|
||||||
@@ -983,9 +1063,13 @@ class ModelPatcher:
|
|||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def prepare_state(self, timestep):
|
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
callback(self, timestep)
|
callback(self, timestep, model_options, ignore_multigpu)
|
||||||
|
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||||
|
for p in model_options["multigpu_clones"].values():
|
||||||
|
p: ModelPatcher
|
||||||
|
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||||
|
|
||||||
def restore_hook_patches(self):
|
def restore_hook_patches(self):
|
||||||
if self.hook_patches_backup is not None:
|
if self.hook_patches_backup is not None:
|
||||||
@@ -998,12 +1082,18 @@ class ModelPatcher:
|
|||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
|
multigpu_kf_changed_cache = None
|
||||||
transformer_options = model_options.get("transformer_options", {})
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
for hook in hook_group.hooks:
|
for hook in hook_group.hooks:
|
||||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||||
# this will cause the weights to be recalculated when sampling
|
# this will cause the weights to be recalculated when sampling
|
||||||
if changed:
|
if changed:
|
||||||
|
# cache changed for multigpu usage
|
||||||
|
if "multigpu_clones" in model_options:
|
||||||
|
if multigpu_kf_changed_cache is None:
|
||||||
|
multigpu_kf_changed_cache = []
|
||||||
|
multigpu_kf_changed_cache.append(hook)
|
||||||
# reset current_hooks if contains hook that changed
|
# reset current_hooks if contains hook that changed
|
||||||
if self.current_hooks is not None:
|
if self.current_hooks is not None:
|
||||||
for current_hook in self.current_hooks.hooks:
|
for current_hook in self.current_hooks.hooks:
|
||||||
@@ -1015,6 +1105,28 @@ class ModelPatcher:
|
|||||||
self.cached_hook_patches.pop(cached_group)
|
self.cached_hook_patches.pop(cached_group)
|
||||||
if reset_current_hooks:
|
if reset_current_hooks:
|
||||||
self.patch_hooks(None)
|
self.patch_hooks(None)
|
||||||
|
if "multigpu_clones" in model_options:
|
||||||
|
for p in model_options["multigpu_clones"].values():
|
||||||
|
p: ModelPatcher
|
||||||
|
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||||
|
|
||||||
|
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||||
|
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||||
|
if kf_changed_cache is None:
|
||||||
|
return
|
||||||
|
reset_current_hooks = False
|
||||||
|
# reset current_hooks if contains hook that changed
|
||||||
|
for hook in kf_changed_cache:
|
||||||
|
if self.current_hooks is not None:
|
||||||
|
for current_hook in self.current_hooks.hooks:
|
||||||
|
if current_hook == hook:
|
||||||
|
reset_current_hooks = True
|
||||||
|
break
|
||||||
|
for cached_group in list(self.cached_hook_patches.keys()):
|
||||||
|
if cached_group.contains(hook):
|
||||||
|
self.cached_hook_patches.pop(cached_group)
|
||||||
|
if reset_current_hooks:
|
||||||
|
self.patch_hooks(None)
|
||||||
|
|
||||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||||
registered: comfy.hooks.HookGroup = None):
|
registered: comfy.hooks.HookGroup = None):
|
||||||
|
|||||||
167
comfy/multigpu.py
Normal file
167
comfy/multigpu.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.patcher_extension
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class GPUOptions:
|
||||||
|
def __init__(self, device_index: int, relative_speed: float):
|
||||||
|
self.device_index = device_index
|
||||||
|
self.relative_speed = relative_speed
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return GPUOptions(self.device_index, self.relative_speed)
|
||||||
|
|
||||||
|
def create_dict(self):
|
||||||
|
return {
|
||||||
|
"relative_speed": self.relative_speed
|
||||||
|
}
|
||||||
|
|
||||||
|
class GPUOptionsGroup:
|
||||||
|
def __init__(self):
|
||||||
|
self.options: dict[int, GPUOptions] = {}
|
||||||
|
|
||||||
|
def add(self, info: GPUOptions):
|
||||||
|
self.options[info.device_index] = info
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
c = GPUOptionsGroup()
|
||||||
|
for opt in self.options.values():
|
||||||
|
c.add(opt)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def register(self, model: ModelPatcher):
|
||||||
|
opts_dict = {}
|
||||||
|
# get devices that are valid for this model
|
||||||
|
devices: list[torch.device] = [model.load_device]
|
||||||
|
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||||
|
extra_model: ModelPatcher
|
||||||
|
devices.append(extra_model.load_device)
|
||||||
|
# create dictionary with actual device mapped to its GPUOptions
|
||||||
|
device_opts_list: list[GPUOptions] = []
|
||||||
|
for device in devices:
|
||||||
|
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||||
|
opts_dict[device] = device_opts.create_dict()
|
||||||
|
device_opts_list.append(device_opts)
|
||||||
|
# make relative_speed relative to 1.0
|
||||||
|
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||||
|
for value in opts_dict.values():
|
||||||
|
value['relative_speed'] /= min_speed
|
||||||
|
model.model_options['multigpu_options'] = opts_dict
|
||||||
|
|
||||||
|
|
||||||
|
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||||
|
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||||
|
model = model.clone()
|
||||||
|
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||||
|
skip_devices = set()
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) > 0:
|
||||||
|
for mm in multigpu_models:
|
||||||
|
skip_devices.add(mm.load_device)
|
||||||
|
skip_devices = list(skip_devices)
|
||||||
|
|
||||||
|
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||||
|
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||||
|
extra_devices = limit_extra_devices.copy()
|
||||||
|
# exclude skipped devices
|
||||||
|
for skip in skip_devices:
|
||||||
|
if skip in extra_devices:
|
||||||
|
extra_devices.remove(skip)
|
||||||
|
# create new deepclones
|
||||||
|
if len(extra_devices) > 0:
|
||||||
|
for device in extra_devices:
|
||||||
|
device_patcher = None
|
||||||
|
if reuse_loaded:
|
||||||
|
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||||
|
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||||
|
for lm in loaded_models:
|
||||||
|
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||||
|
device_patcher = lm.clone()
|
||||||
|
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||||
|
break
|
||||||
|
if device_patcher is None:
|
||||||
|
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||||
|
device_patcher.is_multigpu_base_clone = True
|
||||||
|
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
multigpu_models.append(device_patcher)
|
||||||
|
model.set_additional_models("multigpu", multigpu_models)
|
||||||
|
model.match_multigpu_clones()
|
||||||
|
if gpu_options is None:
|
||||||
|
gpu_options = GPUOptionsGroup()
|
||||||
|
gpu_options.register(model)
|
||||||
|
else:
|
||||||
|
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||||
|
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
||||||
|
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||||
|
# new_multigpu_models = []
|
||||||
|
# for m in multigpu_models:
|
||||||
|
# if m.load_device in limit_extra_devices:
|
||||||
|
# new_multigpu_models.append(m)
|
||||||
|
# model.set_additional_models("multigpu", new_multigpu_models)
|
||||||
|
# persist skip_devices for use in sampling code
|
||||||
|
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||||
|
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||||
|
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||||
|
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||||
|
opts_dict = model_options['multigpu_options']
|
||||||
|
devices = list(model_options['multigpu_clones'].keys())
|
||||||
|
speed_per_device = []
|
||||||
|
work_per_device = []
|
||||||
|
# get sum of each device's relative_speed
|
||||||
|
total_speed = 0.0
|
||||||
|
for opts in opts_dict.values():
|
||||||
|
total_speed += opts['relative_speed']
|
||||||
|
# get relative work for each device;
|
||||||
|
# obtained by w = (W*r)/R
|
||||||
|
for device in devices:
|
||||||
|
relative_speed = opts_dict[device]['relative_speed']
|
||||||
|
relative_work = (total_work*relative_speed) / total_speed
|
||||||
|
speed_per_device.append(relative_speed)
|
||||||
|
work_per_device.append(relative_work)
|
||||||
|
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||||
|
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||||
|
work_per_device = round_preserved(work_per_device)
|
||||||
|
dict_work_per_device = {}
|
||||||
|
for device, relative_work in zip(devices, work_per_device):
|
||||||
|
dict_work_per_device[device] = relative_work
|
||||||
|
if not return_idle_time:
|
||||||
|
return LoadBalance(dict_work_per_device, None)
|
||||||
|
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||||
|
# time here is relative and does not correspond to real-world units
|
||||||
|
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||||
|
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||||
|
idle_time = abs(min(completion_time) - max(completion_time))
|
||||||
|
# if need to compare work idle time, need to normalize to a common total work
|
||||||
|
if work_normalized:
|
||||||
|
idle_time *= (work_normalized/total_work)
|
||||||
|
|
||||||
|
return LoadBalance(dict_work_per_device, idle_time)
|
||||||
|
|
||||||
|
def round_preserved(values: list[float]):
|
||||||
|
'Round all values in a list, preserving the combined sum of values.'
|
||||||
|
# get floor of values; casting to int does it too
|
||||||
|
floored = [int(x) for x in values]
|
||||||
|
total_floored = sum(floored)
|
||||||
|
# get remainder to distribute
|
||||||
|
remainder = round(sum(values)) - total_floored
|
||||||
|
# pair values with fractional portions
|
||||||
|
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||||
|
# sort by fractional part in descending order
|
||||||
|
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
# distribute the remainder
|
||||||
|
for i in range(remainder):
|
||||||
|
index = fractional[i][0]
|
||||||
|
floored[index] += 1
|
||||||
|
return floored
|
||||||
@@ -3,6 +3,8 @@ from typing import Callable
|
|||||||
|
|
||||||
class CallbacksMP:
|
class CallbacksMP:
|
||||||
ON_CLONE = "on_clone"
|
ON_CLONE = "on_clone"
|
||||||
|
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||||
|
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||||
ON_LOAD = "on_load_after"
|
ON_LOAD = "on_load_after"
|
||||||
ON_DETACH = "on_detach_after"
|
ON_DETACH = "on_detach_after"
|
||||||
ON_CLEANUP = "on_cleanup"
|
ON_CLEANUP = "on_cleanup"
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
import uuid
|
import uuid
|
||||||
import math
|
import math
|
||||||
import collections
|
import collections
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
|
import comfy.model_patcher
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -106,6 +108,47 @@ def cleanup_additional_models(models):
|
|||||||
if hasattr(m, 'cleanup'):
|
if hasattr(m, 'cleanup'):
|
||||||
m.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
|
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||||
|
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||||
|
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||||
|
if len(multigpu_models) == 0:
|
||||||
|
return
|
||||||
|
extra_devices = [x.load_device for x in multigpu_models]
|
||||||
|
# handle controlnets
|
||||||
|
controlnets: set[ControlBase] = set()
|
||||||
|
for k in conds:
|
||||||
|
for kk in conds[k]:
|
||||||
|
if 'control' in kk:
|
||||||
|
controlnets.add(kk['control'])
|
||||||
|
if len(controlnets) > 0:
|
||||||
|
# first, unload all controlnet clones
|
||||||
|
for cnet in list(controlnets):
|
||||||
|
cnet_models = cnet.get_models()
|
||||||
|
for cm in cnet_models:
|
||||||
|
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||||
|
|
||||||
|
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
for device in extra_devices:
|
||||||
|
if device not in curr_cnet.multigpu_clones:
|
||||||
|
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||||
|
curr_cnet = curr_cnet.previous_controlnet
|
||||||
|
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||||
|
for cnet in controlnets:
|
||||||
|
curr_cnet = cnet
|
||||||
|
while curr_cnet is not None:
|
||||||
|
prev_cnet = curr_cnet.previous_controlnet
|
||||||
|
for device in extra_devices:
|
||||||
|
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||||
|
prev_device_cnet = None
|
||||||
|
if prev_cnet is not None:
|
||||||
|
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||||
|
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||||
|
curr_cnet = prev_cnet
|
||||||
|
# potentially handle gligen - since not widely used, ignored for now
|
||||||
|
|
||||||
def estimate_memory(model, noise_shape, conds):
|
def estimate_memory(model, noise_shape, conds):
|
||||||
cond_shapes = collections.defaultdict(list)
|
cond_shapes = collections.defaultdict(list)
|
||||||
cond_shapes_min = {}
|
cond_shapes_min = {}
|
||||||
@@ -130,7 +173,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
|||||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||||
real_model: BaseModel = None
|
model.match_multigpu_clones()
|
||||||
|
preprocess_multigpu_conds(conds, model, model_options)
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
@@ -149,7 +193,7 @@ def cleanup_models(conds, models):
|
|||||||
|
|
||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||||
'''
|
'''
|
||||||
Registers hooks from conds.
|
Registers hooks from conds.
|
||||||
'''
|
'''
|
||||||
@@ -182,3 +226,18 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
|||||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||||
copy_dict1=False)
|
copy_dict1=False)
|
||||||
return to_load_options
|
return to_load_options
|
||||||
|
|
||||||
|
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||||
|
'''
|
||||||
|
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||||
|
'''
|
||||||
|
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||||
|
if len(multigpu_patchers) > 0:
|
||||||
|
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||||
|
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||||
|
for x in multigpu_patchers:
|
||||||
|
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||||
|
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||||
|
multigpu_dict[x.load_device] = x
|
||||||
|
model_options["multigpu_clones"] = multigpu_dict
|
||||||
|
return multigpu_patchers
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||||
@@ -18,6 +20,7 @@ import comfy.patcher_extension
|
|||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
def add_area_dims(area, num_dims):
|
def add_area_dims(area, num_dims):
|
||||||
@@ -140,7 +143,7 @@ def can_concat_cond(c1, c2):
|
|||||||
|
|
||||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||||
|
|
||||||
def cond_cat(c_list):
|
def cond_cat(c_list, device=None):
|
||||||
temp = {}
|
temp = {}
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
for k in x:
|
for k in x:
|
||||||
@@ -152,6 +155,8 @@ def cond_cat(c_list):
|
|||||||
for k in temp:
|
for k in temp:
|
||||||
conds = temp[k]
|
conds = temp[k]
|
||||||
out[k] = conds[0].concat(conds[1:])
|
out[k] = conds[0].concat(conds[1:])
|
||||||
|
if device is not None and hasattr(out[k], 'to'):
|
||||||
|
out[k] = out[k].to(device)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -205,7 +210,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
|
|||||||
)
|
)
|
||||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
if 'multigpu_clones' in model_options:
|
||||||
|
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
# separate conds by matching hooks
|
# separate conds by matching hooks
|
||||||
@@ -237,7 +244,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
if has_default_conds:
|
if has_default_conds:
|
||||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
model.current_patcher.prepare_state(timestep)
|
model.current_patcher.prepare_state(timestep, model_options)
|
||||||
|
|
||||||
# run every hooked_to_run separately
|
# run every hooked_to_run separately
|
||||||
for hooks, to_run in hooked_to_run.items():
|
for hooks, to_run in hooked_to_run.items():
|
||||||
@@ -345,6 +352,190 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
|||||||
|
|
||||||
return out_conds
|
return out_conds
|
||||||
|
|
||||||
|
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
out_conds = []
|
||||||
|
out_counts = []
|
||||||
|
# separate conds by matching hooks
|
||||||
|
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||||
|
default_conds = []
|
||||||
|
has_default_conds = False
|
||||||
|
|
||||||
|
output_device = x_in.device
|
||||||
|
|
||||||
|
for i in range(len(conds)):
|
||||||
|
out_conds.append(torch.zeros_like(x_in))
|
||||||
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||||
|
|
||||||
|
cond = conds[i]
|
||||||
|
default_c = []
|
||||||
|
if cond is not None:
|
||||||
|
for x in cond:
|
||||||
|
if 'default' in x:
|
||||||
|
default_c.append(x)
|
||||||
|
has_default_conds = True
|
||||||
|
continue
|
||||||
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
if p.hooks is not None:
|
||||||
|
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||||
|
hooked_to_run.setdefault(p.hooks, list())
|
||||||
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
default_conds.append(default_c)
|
||||||
|
|
||||||
|
if has_default_conds:
|
||||||
|
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
|
model.current_patcher.prepare_state(timestep, model_options)
|
||||||
|
|
||||||
|
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||||
|
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||||
|
|
||||||
|
total_conds = 0
|
||||||
|
for to_run in hooked_to_run.values():
|
||||||
|
total_conds += len(to_run)
|
||||||
|
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
||||||
|
index_device = 0
|
||||||
|
current_device = devices[index_device]
|
||||||
|
# run every hooked_to_run separately
|
||||||
|
for hooks, to_run in hooked_to_run.items():
|
||||||
|
while len(to_run) > 0:
|
||||||
|
current_device = devices[index_device % len(devices)]
|
||||||
|
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||||
|
# keep track of conds currently scheduled onto this device
|
||||||
|
batched_to_run_length = 0
|
||||||
|
for btr in batched_to_run:
|
||||||
|
batched_to_run_length += len(btr[1])
|
||||||
|
|
||||||
|
first = to_run[0]
|
||||||
|
first_shape = first[0][0].shape
|
||||||
|
to_batch_temp = []
|
||||||
|
# make sure not over conds_per_device limit when creating temp batch
|
||||||
|
for x in range(len(to_run)):
|
||||||
|
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
||||||
|
to_batch_temp += [x]
|
||||||
|
|
||||||
|
to_batch_temp.reverse()
|
||||||
|
to_batch = to_batch_temp[:1]
|
||||||
|
|
||||||
|
free_memory = model_management.get_free_memory(current_device)
|
||||||
|
for i in range(1, len(to_batch_temp) + 1):
|
||||||
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||||
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||||
|
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||||
|
to_batch = batch_amount
|
||||||
|
break
|
||||||
|
conds_to_batch = []
|
||||||
|
for x in to_batch:
|
||||||
|
conds_to_batch.append(to_run.pop(x))
|
||||||
|
batched_to_run_length += len(conds_to_batch)
|
||||||
|
|
||||||
|
batched_to_run.append((hooks, conds_to_batch))
|
||||||
|
if batched_to_run_length >= conds_per_device:
|
||||||
|
index_device += 1
|
||||||
|
|
||||||
|
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
|
||||||
|
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||||
|
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||||
|
# run every hooked_to_run separately
|
||||||
|
with torch.no_grad():
|
||||||
|
for hooks, to_batch in batch_tuple:
|
||||||
|
input_x = []
|
||||||
|
mult = []
|
||||||
|
c = []
|
||||||
|
cond_or_uncond = []
|
||||||
|
uuids = []
|
||||||
|
area = []
|
||||||
|
control: ControlBase = None
|
||||||
|
patches = None
|
||||||
|
for x in to_batch:
|
||||||
|
o = x
|
||||||
|
p = o[0]
|
||||||
|
input_x.append(p.input_x)
|
||||||
|
mult.append(p.mult)
|
||||||
|
c.append(p.conditioning)
|
||||||
|
area.append(p.area)
|
||||||
|
cond_or_uncond.append(o[1])
|
||||||
|
uuids.append(p.uuid)
|
||||||
|
control = p.control
|
||||||
|
patches = p.patches
|
||||||
|
|
||||||
|
batch_chunks = len(cond_or_uncond)
|
||||||
|
input_x = torch.cat(input_x).to(device)
|
||||||
|
c = cond_cat(c, device=device)
|
||||||
|
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||||
|
|
||||||
|
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||||
|
if 'transformer_options' in model_options:
|
||||||
|
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||||
|
model_options['transformer_options'],
|
||||||
|
copy_dict1=False)
|
||||||
|
|
||||||
|
if patches is not None:
|
||||||
|
# TODO: replace with merge_nested_dicts function
|
||||||
|
if "patches" in transformer_options:
|
||||||
|
cur_patches = transformer_options["patches"].copy()
|
||||||
|
for p in patches:
|
||||||
|
if p in cur_patches:
|
||||||
|
cur_patches[p] = cur_patches[p] + patches[p]
|
||||||
|
else:
|
||||||
|
cur_patches[p] = patches[p]
|
||||||
|
transformer_options["patches"] = cur_patches
|
||||||
|
else:
|
||||||
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
|
transformer_options["uuids"] = uuids[:]
|
||||||
|
transformer_options["sigmas"] = timestep
|
||||||
|
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||||
|
transformer_options["multigpu_thread_device"] = device
|
||||||
|
|
||||||
|
cast_transformer_options(transformer_options, device=device)
|
||||||
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
|
if control is not None:
|
||||||
|
device_control = control.get_instance_for_device(device)
|
||||||
|
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||||
|
|
||||||
|
if 'model_function_wrapper' in model_options:
|
||||||
|
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||||
|
else:
|
||||||
|
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||||
|
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||||
|
|
||||||
|
|
||||||
|
results: list[thread_result] = []
|
||||||
|
threads: list[threading.Thread] = []
|
||||||
|
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||||
|
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
||||||
|
threads.append(new_thread)
|
||||||
|
new_thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
for output, mult, area, batch_chunks, cond_or_uncond in results:
|
||||||
|
for o in range(batch_chunks):
|
||||||
|
cond_index = cond_or_uncond[o]
|
||||||
|
a = area[o]
|
||||||
|
if a is None:
|
||||||
|
out_conds[cond_index] += output[o] * mult[o]
|
||||||
|
out_counts[cond_index] += mult[o]
|
||||||
|
else:
|
||||||
|
out_c = out_conds[cond_index]
|
||||||
|
out_cts = out_counts[cond_index]
|
||||||
|
dims = len(a) // 2
|
||||||
|
for i in range(dims):
|
||||||
|
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_c += output[o] * mult[o]
|
||||||
|
out_cts += mult[o]
|
||||||
|
|
||||||
|
for i in range(len(out_conds)):
|
||||||
|
out_conds[i] /= out_counts[i]
|
||||||
|
|
||||||
|
return out_conds
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||||
@@ -642,6 +833,8 @@ def pre_run_control(model, conds):
|
|||||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
|
for device_cnet in x['control'].multigpu_clones.values():
|
||||||
|
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@@ -884,7 +1077,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
to_load_options = model_options.get("to_load_options", None)
|
to_load_options = model_options.get("to_load_options", None)
|
||||||
if to_load_options is None:
|
if to_load_options is None:
|
||||||
return
|
return
|
||||||
|
cast_transformer_options(to_load_options, device, dtype)
|
||||||
|
|
||||||
|
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||||
casts = []
|
casts = []
|
||||||
if device is not None:
|
if device is not None:
|
||||||
casts.append(device)
|
casts.append(device)
|
||||||
@@ -893,18 +1088,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
# if nothing to apply, do nothing
|
# if nothing to apply, do nothing
|
||||||
if len(casts) == 0:
|
if len(casts) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# try to call .to on patches
|
# try to call .to on patches
|
||||||
if "patches" in to_load_options:
|
if "patches" in transformer_options:
|
||||||
patches = to_load_options["patches"]
|
patches = transformer_options["patches"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for i in range(len(patch_list)):
|
for i in range(len(patch_list)):
|
||||||
if hasattr(patch_list[i], "to"):
|
if hasattr(patch_list[i], "to"):
|
||||||
for cast in casts:
|
for cast in casts:
|
||||||
patch_list[i] = patch_list[i].to(cast)
|
patch_list[i] = patch_list[i].to(cast)
|
||||||
if "patches_replace" in to_load_options:
|
if "patches_replace" in transformer_options:
|
||||||
patches = to_load_options["patches_replace"]
|
patches = transformer_options["patches_replace"]
|
||||||
for name in patches:
|
for name in patches:
|
||||||
patch_list = patches[name]
|
patch_list = patches[name]
|
||||||
for k in patch_list:
|
for k in patch_list:
|
||||||
@@ -914,8 +1108,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
# try to call .to on any wrappers/callbacks
|
# try to call .to on any wrappers/callbacks
|
||||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||||
for wc_name in wrappers_and_callbacks:
|
for wc_name in wrappers_and_callbacks:
|
||||||
if wc_name in to_load_options:
|
if wc_name in transformer_options:
|
||||||
wc: dict[str, list] = to_load_options[wc_name]
|
wc: dict[str, list] = transformer_options[wc_name]
|
||||||
for wc_dict in wc.values():
|
for wc_dict in wc.values():
|
||||||
for wc_list in wc_dict.values():
|
for wc_list in wc_dict.values():
|
||||||
for i in range(len(wc_list)):
|
for i in range(len(wc_list)):
|
||||||
@@ -923,7 +1117,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
|||||||
for cast in casts:
|
for cast in casts:
|
||||||
wc_list[i] = wc_list[i].to(cast)
|
wc_list[i] = wc_list[i].to(cast)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuider:
|
class CFGGuider:
|
||||||
def __init__(self, model_patcher: ModelPatcher):
|
def __init__(self, model_patcher: ModelPatcher):
|
||||||
self.model_patcher = model_patcher
|
self.model_patcher = model_patcher
|
||||||
@@ -969,6 +1162,8 @@ class CFGGuider:
|
|||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
|
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||||
|
|
||||||
@@ -979,9 +1174,13 @@ class CFGGuider:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
|
for multigpu_patcher in multigpu_patchers:
|
||||||
|
multigpu_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||||
finally:
|
finally:
|
||||||
self.model_patcher.cleanup()
|
self.model_patcher.cleanup()
|
||||||
|
for multigpu_patcher in multigpu_patchers:
|
||||||
|
multigpu_patcher.cleanup()
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
|
|||||||
@@ -49,16 +49,10 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
|||||||
else:
|
else:
|
||||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||||
|
|
||||||
def is_html_file(file_path):
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
content = f.read(100)
|
|
||||||
return b"<!DOCTYPE html>" in content or b"<html" in content
|
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|
||||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||||
try:
|
try:
|
||||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||||
@@ -68,8 +62,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
if return_metadata:
|
if return_metadata:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if is_html_file(ckpt):
|
|
||||||
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
|
|
||||||
if len(e.args) > 0:
|
if len(e.args) > 0:
|
||||||
message = e.args[0]
|
message = e.args[0]
|
||||||
if "HeaderTooLarge" in message:
|
if "HeaderTooLarge" in message:
|
||||||
@@ -96,13 +88,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
else:
|
else:
|
||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
|
|
||||||
try:
|
|
||||||
from app.model_processor import model_processor
|
|
||||||
model_processor.process_file(ckpt)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error processing file {ckpt}: {e}")
|
|
||||||
|
|
||||||
return (sd, metadata) if return_metadata else sd
|
return (sd, metadata) if return_metadata else sd
|
||||||
|
|
||||||
def save_torch_file(sd, ckpt, metadata=None):
|
def save_torch_file(sd, ckpt, metadata=None):
|
||||||
|
|||||||
86
comfy_extras/nodes_multigpu.py
Normal file
86
comfy_extras/nodes_multigpu.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from inspect import cleandoc
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
import comfy.multigpu
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUWorkUnitsNode:
|
||||||
|
"""
|
||||||
|
Prepares model to have sampling accelerated via splitting work units.
|
||||||
|
|
||||||
|
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||||
|
|
||||||
|
Other than those exceptions, this node can be placed in any order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NodeId = "MultiGPU_WorkUnits"
|
||||||
|
NodeName = "MultiGPU Work Units"
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"gpu_options": ("GPU_OPTIONS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "init_multigpu"
|
||||||
|
CATEGORY = "advanced/multigpu"
|
||||||
|
DESCRIPTION = cleandoc(__doc__)
|
||||||
|
|
||||||
|
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||||
|
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
class MultiGPUOptionsNode:
|
||||||
|
"""
|
||||||
|
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NodeId = "MultiGPU_Options"
|
||||||
|
NodeName = "MultiGPU Options"
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
|
||||||
|
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"gpu_options": ("GPU_OPTIONS",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("GPU_OPTIONS",)
|
||||||
|
FUNCTION = "create_gpu_options"
|
||||||
|
CATEGORY = "advanced/multigpu"
|
||||||
|
DESCRIPTION = cleandoc(__doc__)
|
||||||
|
|
||||||
|
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||||
|
if not gpu_options:
|
||||||
|
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||||
|
gpu_options.clone()
|
||||||
|
|
||||||
|
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||||
|
gpu_options.add(opt)
|
||||||
|
|
||||||
|
return (gpu_options,)
|
||||||
|
|
||||||
|
|
||||||
|
node_list = [
|
||||||
|
MultiGPUWorkUnitsNode,
|
||||||
|
MultiGPUOptionsNode
|
||||||
|
]
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
|
for node in node_list:
|
||||||
|
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
||||||
@@ -275,7 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
|
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in folder_names_and_paths:
|
if folder_name not in folder_names_and_paths:
|
||||||
@@ -288,8 +288,6 @@ def get_full_path(folder_name: str, filename: str, allow_missing: bool = False)
|
|||||||
return full_path
|
return full_path
|
||||||
elif os.path.islink(full_path):
|
elif os.path.islink(full_path):
|
||||||
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
|
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
|
||||||
elif allow_missing:
|
|
||||||
return full_path
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -301,27 +299,6 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
|||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_relative_path(full_path: str) -> tuple[str, str] | None:
|
|
||||||
"""Convert a full path back to a type-relative path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
full_path: The full path to the file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
|
|
||||||
"""
|
|
||||||
global folder_names_and_paths
|
|
||||||
full_path = os.path.normpath(full_path)
|
|
||||||
|
|
||||||
for model_type, (paths, _) in folder_names_and_paths.items():
|
|
||||||
for base_path in paths:
|
|
||||||
base_path = os.path.normpath(base_path)
|
|
||||||
if full_path.startswith(base_path):
|
|
||||||
relative_path = os.path.relpath(full_path, base_path)
|
|
||||||
return model_type, relative_path
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
|||||||
9
main.py
9
main.py
@@ -147,6 +147,7 @@ def cuda_malloc_warning():
|
|||||||
if cuda_malloc_warning:
|
if cuda_malloc_warning:
|
||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
@@ -236,13 +237,6 @@ def cleanup_temp():
|
|||||||
if os.path.exists(temp_dir):
|
if os.path.exists(temp_dir):
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
def setup_database():
|
|
||||||
try:
|
|
||||||
from app.database.db import init_db, dependencies_available
|
|
||||||
if dependencies_available():
|
|
||||||
init_db()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
|
||||||
|
|
||||||
def start_comfyui(asyncio_loop=None):
|
def start_comfyui(asyncio_loop=None):
|
||||||
"""
|
"""
|
||||||
@@ -272,7 +266,6 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
setup_database()
|
|
||||||
|
|
||||||
prompt_server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(prompt_server)
|
hijack_progress(prompt_server)
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@@ -2241,6 +2241,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_mahiro.py",
|
"nodes_mahiro.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
|
"nodes_multigpu.py",
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
"nodes_cosmos.py",
|
"nodes_cosmos.py",
|
||||||
"nodes_video.py",
|
"nodes_video.py",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.21.3
|
comfyui-frontend-package==1.21.3
|
||||||
comfyui-workflow-templates==0.1.25
|
comfyui-workflow-templates==0.1.23
|
||||||
comfyui-embedded-docs==0.2.0
|
comfyui-embedded-docs==0.2.0
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@@ -18,9 +18,6 @@ Pillow
|
|||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
psutil
|
psutil
|
||||||
alembic
|
|
||||||
SQLAlchemy
|
|
||||||
blake3
|
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
|
|||||||
@@ -1,253 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from app.model_processor import ModelProcessor
|
|
||||||
from app.database.models import Model, Base
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Test data constants
|
|
||||||
TEST_MODEL_TYPE = "checkpoints"
|
|
||||||
TEST_URL = "http://example.com/model.safetensors"
|
|
||||||
TEST_FILE_NAME = "model.safetensors"
|
|
||||||
TEST_EXPECTED_HASH = "abc123"
|
|
||||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
|
||||||
"""Helper to create a test model in the database."""
|
|
||||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
|
||||||
"""Helper to setup hash calculation mocks."""
|
|
||||||
mock_hash = MagicMock()
|
|
||||||
mock_hash.hexdigest.return_value = hash_value
|
|
||||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
|
||||||
|
|
||||||
|
|
||||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
|
||||||
"""Helper to verify model exists in database with correct attributes."""
|
|
||||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
|
||||||
assert db_model is not None
|
|
||||||
if expected_hash:
|
|
||||||
assert db_model.hash == expected_hash
|
|
||||||
if expected_type:
|
|
||||||
assert db_model.type == expected_type
|
|
||||||
return db_model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_engine():
|
|
||||||
# Configure in-memory database
|
|
||||||
engine = create_engine("sqlite:///:memory:")
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
Base.metadata.drop_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_session(db_engine):
|
|
||||||
Session = sessionmaker(bind=db_engine)
|
|
||||||
session = Session()
|
|
||||||
yield session
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_get_relative_path():
|
|
||||||
with patch("app.model_processor.get_relative_path") as mock:
|
|
||||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_get_full_path():
|
|
||||||
with patch("app.model_processor.get_full_path") as mock:
|
|
||||||
mock.return_value = TEST_DESTINATION_PATH
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
|
|
||||||
with patch("app.model_processor.create_session", return_value=db_session):
|
|
||||||
with patch("app.model_processor.can_create_session", return_value=True):
|
|
||||||
processor = ModelProcessor()
|
|
||||||
# Setup test state
|
|
||||||
processor.removed_files = []
|
|
||||||
processor.downloaded_files = []
|
|
||||||
processor.file_exists = {}
|
|
||||||
|
|
||||||
def mock_download_file(url, destination_path, hasher):
|
|
||||||
processor.downloaded_files.append((url, destination_path))
|
|
||||||
processor.file_exists[destination_path] = True
|
|
||||||
# Simulate writing some data to the file
|
|
||||||
test_data = b"test data"
|
|
||||||
hasher.update(test_data)
|
|
||||||
|
|
||||||
def mock_remove_file(file_path):
|
|
||||||
processor.removed_files.append(file_path)
|
|
||||||
if file_path in processor.file_exists:
|
|
||||||
del processor.file_exists[file_path]
|
|
||||||
|
|
||||||
# Setup common patches
|
|
||||||
file_exists_patch = patch.object(
|
|
||||||
processor,
|
|
||||||
"_file_exists",
|
|
||||||
side_effect=lambda path: processor.file_exists.get(path, False),
|
|
||||||
)
|
|
||||||
file_size_patch = patch.object(
|
|
||||||
processor,
|
|
||||||
"_get_file_size",
|
|
||||||
side_effect=lambda path: (
|
|
||||||
1000 if processor.file_exists.get(path, False) else 0
|
|
||||||
),
|
|
||||||
)
|
|
||||||
download_file_patch = patch.object(
|
|
||||||
processor, "_download_file", side_effect=mock_download_file
|
|
||||||
)
|
|
||||||
remove_file_patch = patch.object(
|
|
||||||
processor, "_remove_file", side_effect=mock_remove_file
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
file_exists_patch,
|
|
||||||
file_size_patch,
|
|
||||||
download_file_patch,
|
|
||||||
remove_file_patch,
|
|
||||||
):
|
|
||||||
yield processor
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_invalid_extension(model_processor):
|
|
||||||
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
|
|
||||||
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
|
|
||||||
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
|
|
||||||
# Ensure that a file with the same hash but from a different source is not downloaded again
|
|
||||||
SOURCE_URL = "https://example.com/other.sft"
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
result = model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == TEST_DESTINATION_PATH
|
|
||||||
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
|
|
||||||
# Ensure that a file with a different hash raises an error
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
|
|
||||||
model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_new_file(model_processor, db_session):
|
|
||||||
# Ensure that a new file is downloaded
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
||||||
|
|
||||||
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
|
|
||||||
result = model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == TEST_DESTINATION_PATH
|
|
||||||
assert len(model_processor.downloaded_files) == 1
|
|
||||||
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
|
|
||||||
assert model_processor.file_exists[TEST_DESTINATION_PATH]
|
|
||||||
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
|
|
||||||
# Ensure that download that results in a different hash raises an error
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
||||||
|
|
||||||
with setup_mock_hash_calculation(model_processor, "different_hash"):
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match="Downloaded file hash .* does not match expected hash .*",
|
|
||||||
):
|
|
||||||
model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE,
|
|
||||||
TEST_URL,
|
|
||||||
TEST_FILE_NAME,
|
|
||||||
TEST_EXPECTED_HASH,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(model_processor.removed_files) == 1
|
|
||||||
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
|
|
||||||
assert TEST_DESTINATION_PATH not in model_processor.file_exists
|
|
||||||
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_without_hash(model_processor, db_session):
|
|
||||||
# Test processing file without provided hash
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
|
||||||
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_model_by_hash(model_processor, db_session):
|
|
||||||
# Test retrieving model by hash
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
|
|
||||||
# Test retrieving model by hash and type
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
assert result.type == TEST_MODEL_TYPE
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_hash(model_processor, db_session):
|
|
||||||
# Test retrieving hash for existing model
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
with patch.object(
|
|
||||||
model_processor,
|
|
||||||
"_validate_path",
|
|
||||||
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
|
|
||||||
):
|
|
||||||
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
|
|
||||||
assert result == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_file_extension_valid_extensions(model_processor):
|
|
||||||
# Test all valid file extensions
|
|
||||||
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
|
|
||||||
for ext in valid_extensions:
|
|
||||||
model_processor._validate_file_extension(f"test{ext}") # Should not raise
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_existing_without_source_url(model_processor, db_session):
|
|
||||||
# Test processing an existing file that needs its source URL updated
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
assert result.source_url == TEST_URL
|
|
||||||
|
|
||||||
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
|
|
||||||
assert db_model.source_url == TEST_URL
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# The path to the requirements.txt file
|
|
||||||
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
|
||||||
|
|
||||||
|
|
||||||
def get_missing_requirements_message():
|
|
||||||
"""The warning message to display when a package is missing."""
|
|
||||||
|
|
||||||
extra = ""
|
|
||||||
if sys.flags.no_user_site:
|
|
||||||
extra = "-s "
|
|
||||||
return f"""
|
|
||||||
Please install the updated requirements.txt file by running:
|
|
||||||
{sys.executable} {extra}-m pip install -r {requirements_path}
|
|
||||||
|
|
||||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
|
||||||
""".strip()
|
|
||||||
Reference in New Issue
Block a user