Compare commits
2 Commits
v0.3.11 ... model_mana

| Author | SHA1 | Date |
|---|---|---|
|  | fde9fdddff |  |
|  | 7bf381bc9e |  |

2  .github/workflows/test-build.yml  vendored
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

119  alembic.ini  Normal file
@@ -0,0 +1,119 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic_db/versions.  When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = sqlite:///user/comfyui.db


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.  See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

3  alembic_db/README.md  Normal file
@@ -0,0 +1,3 @@
## Generate new revision
1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`

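(Not part of the diff.) The README step above assumes `alembic` is run from the repository root so the `alembic.ini` and `alembic_db/` paths shown earlier resolve. As an illustrative sketch, the same autogenerate step can also be driven through Alembic's Python command API; the message string here is a placeholder.

# Sketch only: generate a new Alembic revision programmatically from the repo root.
from alembic import command
from alembic.config import Config

config = Config("alembic.ini")
# Equivalent to: alembic revision --autogenerate -m "your message"
command.revision(config, message="your message", autogenerate=True)
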
75  alembic_db/env.py  Normal file
@@ -0,0 +1,75 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

from app.database.models import Base
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

28  alembic_db/script.py.mako  Normal file
@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}

58  alembic_db/versions/2fb22c4fff36_init.py  Normal file
@@ -0,0 +1,58 @@
"""init

Revision ID: 2fb22c4fff36
Revises:
Create Date: 2025-03-27 19:00:47.686079

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '2fb22c4fff36'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('model',
    sa.Column('type', sa.Text(), nullable=False),
    sa.Column('path', sa.Text(), nullable=False),
    sa.Column('title', sa.Text(), nullable=True),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('architecture', sa.Text(), nullable=True),
    sa.Column('hash', sa.Text(), nullable=True),
    sa.Column('source_url', sa.Text(), nullable=True),
    sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
    sa.PrimaryKeyConstraint('type', 'path')
    )
    op.create_table('tag',
    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
    sa.Column('name', sa.Text(), nullable=False),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('name')
    )
    op.create_table('model_tag',
    sa.Column('model_type', sa.Text(), nullable=False),
    sa.Column('model_path', sa.Text(), nullable=False),
    sa.Column('tag_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('model_tag')
    op.drop_table('tag')
    op.drop_table('model')
    # ### end Alembic commands ###

118  app/database/db.py  Normal file
@@ -0,0 +1,118 @@
import logging
import os
import shutil
import sys
from app.database.models import Tag
from comfy.cli_args import args

try:
    import alembic
    import sqlalchemy
except ImportError as e:
    req_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
    )
    logging.error(
        f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
    )
    exit(-1)

from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

Session = None


def get_alembic_config():
    root_path = os.path.join(os.path.dirname(__file__), "../..")
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))

    config = Config(config_path)
    config.set_main_option("script_location", scripts_path)
    config.set_main_option("sqlalchemy.url", args.database_url)

    return config


def get_db_path():
    url = args.database_url
    if url.startswith("sqlite:///"):
        return url.split("///")[1]
    else:
        raise ValueError(f"Unsupported database URL '{url}'.")


def init_db():
    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")

    config = get_alembic_config()

    # Check if we need to upgrade
    engine = create_engine(db_url)
    conn = engine.connect()

    context = MigrationContext.configure(conn)
    current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if current_rev != target_rev:
        # Backup the database pre upgrade
        db_path = get_db_path()
        backup_path = db_path + ".bkp"
        if os.path.exists(db_path):
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.error(f"Error upgrading database: {e}")
            raise e

    global Session
    Session = sessionmaker(bind=engine)

    if not current_rev:
        # Init db, populate models
        from app.model_processor import model_processor

        session = create_session()
        model_processor.populate_models(session)

        # populate tags
        tags = (
            "character",
            "style",
            "concept",
            "clothing",
            "pose",
            "background",
            "vehicle",
            "object",
            "animal",
            "action",
        )
        for tag in tags:
            session.add(Tag(name=tag))

        session.commit()


def can_create_session():
    return Session is not None


def create_session():
    return Session()

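(Not part of the diff.) A minimal sketch of how this module is meant to be driven by the rest of the application, assuming the Comfy CLI arguments have already been parsed as in normal startup: init_db() runs any pending migrations (backing up the SQLite file first) and binds the module-level Session factory, after which create_session() hands out sessions.

# Sketch: typical startup/usage flow for app.database.db.
from app.database import db
from app.database.models import Model

db.init_db()  # run migrations, back up/restore on failure, bind Session

if db.can_create_session():
    with db.create_session() as session:
        count = session.query(Model).count()
        print(f"{count} models tracked in the database")
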
76  app/database/models.py  Normal file
@@ -0,0 +1,76 @@
from sqlalchemy import (
    Column,
    Integer,
    Text,
    DateTime,
    Table,
    ForeignKeyConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.sql import func

Base = declarative_base()


def to_dict(obj):
    fields = obj.__table__.columns.keys()
    return {
        field: (val.to_dict() if hasattr(val, "to_dict") else val)
        for field in fields
        if (val := getattr(obj, field))
    }


ModelTag = Table(
    "model_tag",
    Base.metadata,
    Column(
        "model_type",
        Text,
        primary_key=True,
    ),
    Column(
        "model_path",
        Text,
        primary_key=True,
    ),
    Column("tag_id", Integer, primary_key=True),
    ForeignKeyConstraint(
        ["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
    ),
    ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
)


class Model(Base):
    __tablename__ = "model"

    type = Column(Text, primary_key=True)
    path = Column(Text, primary_key=True)
    title = Column(Text)
    description = Column(Text)
    architecture = Column(Text)
    hash = Column(Text)
    source_url = Column(Text)
    date_added = Column(DateTime, server_default=func.now())

    # Relationship with tags
    tags = relationship("Tag", secondary=ModelTag, back_populates="models")

    def to_dict(self):
        dict = to_dict(self)
        dict["tags"] = [tag.to_dict() for tag in self.tags]
        return dict


class Tag(Base):
    __tablename__ = "tag"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False, unique=True)

    # Relationship with models
    models = relationship("Model", secondary=ModelTag, back_populates="tags")

    def to_dict(self):
        return to_dict(self)

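(Not part of the diff.) To make the schema concrete, here is a small self-contained sketch that creates the same tables on a throwaway in-memory SQLite database and exercises the Model/Tag relationship and the to_dict helper.

# Sketch: exercise the ORM models against an in-memory SQLite database.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.database.models import Base, Model, Tag

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)  # same tables the Alembic "init" revision creates

with sessionmaker(bind=engine)() as session:
    tag = Tag(name="style")
    model = Model(type="checkpoints", path="example.safetensors", title="Example")
    model.tags.append(tag)          # goes through the model_tag association table
    session.add(model)
    session.commit()
    print(model.to_dict())          # includes the nested "tags" list
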
@@ -1,19 +1,30 @@
 from __future__ import annotations

 import os
-import base64
-import json
 import time
 import logging
+from app.database.db import create_session
 import folder_paths
-import glob
-import comfy.utils
 from aiohttp import web
 from PIL import Image
 from io import BytesIO
-from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
+from folder_paths import map_legacy, filter_files_extensions, get_full_path
+from app.database.models import Tag, Model
+from app.model_processor import get_model_previews, model_processor
+from utils.web import dumps
+from sqlalchemy.orm import joinedload
+import sqlalchemy.exc
+
+
+def bad_request(message: str):
+    return web.json_response({"error": message}, status=400)
+
+def missing_field(field: str):
+    return bad_request(f"{field} is required")
+
+def not_found(message: str):
+    return web.json_response({"error": message + " not found"}, status=404)
+
 class ModelFileManager:
     def __init__(self) -> None:
         self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -62,7 +73,7 @@ class ModelFileManager:
             folder = folders[0][path_index]
             full_filename = os.path.join(folder, filename)

-            previews = self.get_model_previews(full_filename)
+            previews = get_model_previews(full_filename)
             default_preview = previews[0] if len(previews) > 0 else None
             if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
                 return web.Response(status=404)
@@ -76,6 +87,183 @@ class ModelFileManager:
             except:
                 return web.Response(status=404)

+        @routes.get("/v2/models")
+        async def get_models(request):
+            with create_session() as session:
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+                query = session.query(Model).options(joinedload(Model.tags))
+                if model_path:
+                    query = query.filter(Model.path == model_path)
+                if model_type:
+                    query = query.filter(Model.type == model_type)
+                models = query.all()
+                if model_path and model_type:
+                    if len(models) == 0:
+                        return not_found("Model")
+                    return web.json_response(models[0].to_dict(), dumps=dumps)
+
+                return web.json_response([model.to_dict() for model in models], dumps=dumps)
+
+        @routes.post("/v2/models")
+        async def add_model(request):
+            with create_session() as session:
+                data = await request.json()
+                model_type = data.get("type", None)
+                model_path = data.get("path", None)
+
+                if not model_type:
+                    return missing_field("type")
+                if not model_path:
+                    return missing_field("path")
+
+                tags = data.pop("tags", [])
+                fields = Model.metadata.tables["model"].columns.keys()
+
+                # Validate keys are valid model fields
+                for key in data.keys():
+                    if key not in fields:
+                        return bad_request(f"Invalid field: {key}")
+
+                # Validate file exists
+                if not get_full_path(model_type, model_path):
+                    return not_found(f"File '{model_type}/{model_path}'")
+
+                model = Model()
+                for field in fields:
+                    if field in data:
+                        setattr(model, field, data[field])
+
+                model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
+                for tag in tags:
+                    if tag not in [t.id for t in model.tags]:
+                        return not_found(f"Tag '{tag}'")
+
+                try:
+                    session.add(model)
+                    session.commit()
+                except sqlalchemy.exc.IntegrityError as e:
+                    session.rollback()
+                    return bad_request(e.orig.args[0])
+
+                model_processor.run()
+
+                return web.json_response(model.to_dict(), dumps=dumps)
+
+        @routes.delete("/v2/models")
+        async def delete_model(request):
+            with create_session() as session:
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+                if not model_path:
+                    return missing_field("path")
+                if not model_type:
+                    return missing_field("type")
+
+                full_path = get_full_path(model_type, model_path)
+                if full_path:
+                    return bad_request("Model file exists, please delete the file before deleting the model record.")
+
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                session.delete(model)
+                session.commit()
+                return web.Response(status=204)
+
+        @routes.get("/v2/tags")
+        async def get_tags(request):
+            with create_session() as session:
+                tags = session.query(Tag).all()
+                return web.json_response(
+                    [{"id": tag.id, "name": tag.name} for tag in tags]
+                )
+
+        @routes.post("/v2/tags")
+        async def create_tag(request):
+            with create_session() as session:
+                data = await request.json()
+                name = data.get("name", None)
+                if not name:
+                    return missing_field("name")
+                tag = Tag(name=name)
+                session.add(tag)
+                session.commit()
+                return web.json_response({"id": tag.id, "name": tag.name})
+
+        @routes.delete("/v2/tags")
+        async def delete_tag(request):
+            with create_session() as session:
+                tag_id = request.query.get("id", None)
+                if not tag_id:
+                    return missing_field("id")
+                tag = session.query(Tag).filter(Tag.id == tag_id).first()
+                if not tag:
+                    return not_found("Tag")
+                session.delete(tag)
+                session.commit()
+                return web.Response(status=204)
+
+        @routes.post("/v2/models/tags")
+        async def add_model_tag(request):
+            with create_session() as session:
+                data = await request.json()
+                tag_id = data.get("tag", None)
+                model_path = data.get("path", None)
+                model_type = data.get("type", None)
+
+                if tag_id is None:
+                    return missing_field("tag")
+                if model_path is None:
+                    return missing_field("path")
+                if model_type is None:
+                    return missing_field("type")
+
+                try:
+                    tag_id = int(tag_id)
+                except ValueError:
+                    return bad_request("Invalid tag id")
+
+                tag = session.query(Tag).filter(Tag.id == tag_id).first()
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                model.tags.append(tag)
+                session.commit()
+                return web.json_response(model.to_dict(), dumps=dumps)
+
+        @routes.delete("/v2/models/tags")
+        async def delete_model_tag(request):
+            with create_session() as session:
+                tag_id = request.query.get("tag", None)
+                model_path = request.query.get("path", None)
+                model_type = request.query.get("type", None)
+
+                if tag_id is None:
+                    return missing_field("tag")
+                if model_path is None:
+                    return missing_field("path")
+                if model_type is None:
+                    return missing_field("type")
+
+                try:
+                    tag_id = int(tag_id)
+                except ValueError:
+                    return bad_request("Invalid tag id")
+
+                model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
+                if not model:
+                    return not_found("Model")
+                model.tags = [tag for tag in model.tags if tag.id != tag_id]
+                session.commit()
+                return web.Response(status=204)
+
+
+        @routes.get("/v2/models/missing")
+        async def get_missing_models(request):
+            return web.json_response(model_processor.missing_models)
+
     def get_model_file_list(self, folder_name: str):
         folder_name = map_legacy(folder_name)
         folders = folder_paths.folder_names_and_paths[folder_name]
@@ -146,39 +334,5 @@ class ModelFileManager:

         return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

-    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
-        dirname = os.path.dirname(filepath)
-
-        if not os.path.exists(dirname):
-            return []
-
-        basename = os.path.splitext(filepath)[0]
-        match_files = glob.glob(f"{basename}.*", recursive=False)
-        image_files = filter_files_content_types(match_files, "image")
-        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
-        safetensors_metadata = {}
-
-        result: list[str | BytesIO] = []
-
-        for filename in image_files:
-            _basename = os.path.splitext(filename)[0]
-            if _basename == basename:
-                result.append(filename)
-            if _basename == f"{basename}.preview":
-                result.append(filename)
-
-        if safetensors_file:
-            safetensors_filepath = os.path.join(dirname, safetensors_file)
-            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
-            if header:
-                safetensors_metadata = json.loads(header)
-            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
-            if safetensors_images:
-                safetensors_images = json.loads(safetensors_images)
-                for image in safetensors_images:
-                    result.append(BytesIO(base64.b64decode(image)))
-
-        return result
-
     def __exit__(self, exc_type, exc_value, traceback):
         self.clear_cache()

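(Not part of the diff.) The routes added above form a small REST surface under /v2. A hedged usage sketch with aiohttp's client follows; the server address and port are assumptions, so adjust them to however the ComfyUI server is actually exposed.

# Sketch: exercising the /v2 model and tag endpoints added above.
# Assumes a running server reachable at http://127.0.0.1:8188 (adjust as needed).
import asyncio
import aiohttp

BASE = "http://127.0.0.1:8188"

async def main():
    async with aiohttp.ClientSession() as http:
        # List all models known to the database
        async with http.get(f"{BASE}/v2/models") as resp:
            print(await resp.json())

        # Create a tag, then register a model record that references it
        async with http.post(f"{BASE}/v2/tags", json={"name": "favorite"}) as resp:
            tag = await resp.json()

        payload = {"type": "checkpoints", "path": "example.safetensors", "tags": [tag["id"]]}
        async with http.post(f"{BASE}/v2/models", json=payload) as resp:
            print(resp.status, await resp.json())

asyncio.run(main())
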
263  app/model_processor.py  Normal file
@@ -0,0 +1,263 @@
import base64
from datetime import datetime
import glob
import hashlib
from io import BytesIO
import json
import logging
import os
import threading
import time
import comfy.utils
from app.database.models import Model
from app.database.db import create_session
from comfy.cli_args import args
from folder_paths import (
    filter_files_content_types,
    get_full_path,
    folder_names_and_paths,
    get_filename_list,
)
from PIL import Image
from urllib import request


def get_model_previews(
    filepath: str, check_metadata: bool = True
) -> list[str | BytesIO]:
    dirname = os.path.dirname(filepath)

    if not os.path.exists(dirname):
        return []

    basename = os.path.splitext(filepath)[0]
    match_files = glob.glob(f"{basename}.*", recursive=False)
    image_files = filter_files_content_types(match_files, "image")

    result: list[str | BytesIO] = []

    for filename in image_files:
        _basename = os.path.splitext(filename)[0]
        if _basename == basename:
            result.append(filename)
        if _basename == f"{basename}.preview":
            result.append(filename)

    if not check_metadata:
        return result

    safetensors_file = next(
        filter(lambda x: x.endswith(".safetensors"), match_files), None
    )
    safetensors_metadata = {}

    if safetensors_file:
        safetensors_filepath = os.path.join(dirname, safetensors_file)
        header = comfy.utils.safetensors_header(
            safetensors_filepath, max_size=8 * 1024 * 1024
        )
        if header:
            safetensors_metadata = json.loads(header)
        safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
            "ssmd_cover_images", None
        )
        if safetensors_images:
            safetensors_images = json.loads(safetensors_images)
            for image in safetensors_images:
                result.append(BytesIO(base64.b64decode(image)))

    return result


class ModelProcessor:
    def __init__(self):
        self._thread = None
        self._lock = threading.Lock()
        self._run = False
        self.missing_models = []

    def run(self):
        if args.disable_model_processing:
            return

        if self._thread is None:
            # Lock to prevent multiple threads from starting
            with self._lock:
                self._run = True
                if self._thread is None:
                    self._thread = threading.Thread(target=self._process_models)
                    self._thread.daemon = True
                    self._thread.start()

    def populate_models(self, session):
        # Ensure database state matches filesystem

        existing_models = session.query(Model).all()

        for folder_name in folder_names_and_paths.keys():
            if folder_name == "custom_nodes" or folder_name == "configs":
                continue
            seen = set()
            files = get_filename_list(folder_name)

            for file in files:
                if file in seen:
                    logging.warning(f"Skipping duplicate named model: {file}")
                    continue
                seen.add(file)

                existing_model = None
                for model in existing_models:
                    if model.path == file and model.type == folder_name:
                        existing_model = model
                        break

                if existing_model:
                    # Model already exists in db, remove from list and skip
                    existing_models.remove(existing_model)
                    continue

                file_path = get_full_path(folder_name, file)

                model = Model(
                    path=file,
                    type=folder_name,
                    date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
                )
                session.add(model)

        for model in existing_models:
            if not get_full_path(model.type, model.path):
                logging.warning(f"Model {model.path} not found")
                self.missing_models.append({"type": model.type, "path": model.path})

        session.commit()

    def _get_models(self, session):
        models = session.query(Model).filter(Model.hash == None).all()
        return models

    def _process_file(self, model_path):
        is_safetensors = model_path.endswith(".safetensors")
        metadata = {}
        h = hashlib.sha256()

        with open(model_path, "rb", buffering=0) as f:
            if is_safetensors:
                # Read header length (8 bytes)
                header_size_bytes = f.read(8)
                header_len = int.from_bytes(header_size_bytes, "little")
                h.update(header_size_bytes)

                # Read header
                header_bytes = f.read(header_len)
                h.update(header_bytes)
                try:
                    metadata = json.loads(header_bytes)
                except json.JSONDecodeError:
                    pass

            # Read rest of file
            b = bytearray(128 * 1024)
            mv = memoryview(b)
            while n := f.readinto(mv):
                h.update(mv[:n])

        return h.hexdigest(), metadata

    def _populate_info(self, model, metadata):
        model.title = metadata.get("modelspec.title", None)
        model.description = metadata.get("modelspec.description", None)
        model.architecture = metadata.get("modelspec.architecture", None)

    def _extract_image(self, model_path, metadata):
        # check if image already exists
        if len(get_model_previews(model_path, check_metadata=False)) > 0:
            return

        image_path = os.path.splitext(model_path)[0] + ".webp"
        if os.path.exists(image_path):
            return

        cover_images = metadata.get("ssmd_cover_images", None)
        image = None
        if cover_images:
            try:
                cover_images = json.loads(cover_images)
                if len(cover_images) > 0:
                    image_data = cover_images[0]
                    image = Image.open(BytesIO(base64.b64decode(image_data)))
            except Exception as e:
                logging.warning(
                    f"Error extracting cover image for model {model_path}: {e}"
                )

        if not image:
            thumbnail = metadata.get("modelspec.thumbnail", None)
            if thumbnail:
                try:
                    response = request.urlopen(thumbnail)
                    image = Image.open(response)
                except Exception as e:
                    logging.warning(
                        f"Error extracting thumbnail for model {model_path}: {e}"
                    )

        if image:
            image.thumbnail((512, 512))
            image.save(image_path)
            image.close()

    def _process_models(self):
        with create_session() as session:
            checked = set()
            self.populate_models(session)

            while self._run:
                self._run = False

                models = self._get_models(session)

                if len(models) == 0:
                    break

                for model in models:
                    # prevent looping on the same model if it crashes
                    if model.path in checked:
                        continue
                    checked.add(model.path)

                    try:
                        time.sleep(0)
                        now = time.time()
                        model_path = get_full_path(model.type, model.path)

                        if not model_path:
                            logging.warning(f"Model {model.path} not found")
                            self.missing_models.append(model.path)
                            continue

                        logging.debug(f"Processing model {model_path}")
                        hash, header = self._process_file(model_path)
                        logging.debug(
                            f"Processed model {model_path} in {time.time() - now} seconds"
                        )
                        model.hash = hash

                        if header:
                            metadata = header.get("__metadata__", None)

                            if metadata:
                                self._populate_info(model, metadata)
                                self._extract_image(model_path, metadata)

                        session.commit()
                    except Exception as e:
                        logging.error(f"Error processing model {model.path}: {e}")

        with self._lock:
            self._thread = None


model_processor = ModelProcessor()

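(Not part of the diff.) The _process_file method above relies on the safetensors container layout: an 8-byte little-endian header length, then a JSON header, then the raw tensor payload. A standalone sketch of the same parse-and-hash pass, independent of ComfyUI; the file path is a placeholder.

# Sketch: hash a .safetensors file while also pulling out its JSON header,
# mirroring ModelProcessor._process_file. The path below is a placeholder.
import hashlib
import json

def hash_and_read_header(path: str):
    h = hashlib.sha256()
    metadata = {}
    with open(path, "rb", buffering=0) as f:
        size_bytes = f.read(8)                 # little-endian header length
        header_len = int.from_bytes(size_bytes, "little")
        h.update(size_bytes)

        header_bytes = f.read(header_len)      # JSON header (tensor index + __metadata__)
        h.update(header_bytes)
        try:
            metadata = json.loads(header_bytes)
        except json.JSONDecodeError:
            pass

        buf = bytearray(128 * 1024)            # stream the tensor payload in chunks
        view = memoryview(buf)
        while n := f.readinto(view):
            h.update(view[:n])
    return h.hexdigest(), metadata.get("__metadata__", {})

digest, meta = hash_and_read_header("example.safetensors")
print(digest, meta.get("modelspec.title"))
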
@@ -178,6 +178,12 @@ parser.add_argument(

 parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")

+database_default_path = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
+)
+parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:

@@ -168,19 +168,15 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            # apply_rotary_pos_emb inlined
-            q_shape = q.shape
-            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
-            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
-            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
-
-            # apply_rotary_pos_emb inlined
-            k_shape = k.shape
-            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
-            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
-            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+            q = apply_rotary_pos_emb(q, rope_emb)
+            k = apply_rotary_pos_emb(k, rope_emb)
         return q, k, v

+    def cal_attn(self, q, k, v, mask=None):
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)
+
     def forward(
         self,
         x,
@@ -195,10 +191,7 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        del q, k, v
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+        return self.cal_attn(q, k, v, mask)


 class FeedForward(nn.Module):
@@ -795,7 +788,10 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
+        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if extra_per_block_pos_emb is not None:
+            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,

@@ -30,8 +30,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging

-from comfy.ldm.modules.diffusionmodules.model import vae_attention
-
 from .patching import (
     Patcher,
     Patcher3D,
@@ -402,8 +400,6 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )

-        self.optimized_attention = vae_attention()
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -417,7 +413,18 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)

         b, c, h, w = q.shape
-        h_ = self.optimized_attention(q, k, v)
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)
+        k = k.reshape(b, c, h * w)
+        w_ = torch.bmm(q, k)
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = F.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)
+        h_ = torch.bmm(v, w_)
+        h_ = h_.reshape(b, c, h, w)
+
         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
@@ -864,16 +871,18 @@ class EncoderFactorized(nn.Module):
         x = self.patcher3d(x)

         # downsampling
-        h = self.conv_in(x)
+        hs = [self.conv_in(x)]
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h)
+                h = self.down[i_level].block[i_block](hs[-1])
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h)
+                hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
+        h = hs[-1]
         h = self.mid.block_1(h)
         h = self.mid.attn_1(h)
         h = self.mid.block_2(h)

@@ -281,76 +281,54 @@ class UnPatcher3D(UnPatcher):
         hh = hh.to(dtype=dtype)

         xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
-        del x

         # Height height transposed convolutions.
         xll = F.conv_transpose3d(
             xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xlll

         xll += F.conv_transpose3d(
             xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xllh

         xlh = F.conv_transpose3d(
             xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xlhl

         xlh += F.conv_transpose3d(
             xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xlhh

         xhl = F.conv_transpose3d(
             xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xhll

         xhl += F.conv_transpose3d(
             xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xhlh

         xhh = F.conv_transpose3d(
             xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xhhl

         xhh += F.conv_transpose3d(
             xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
         )
-        del xhhh

         # Handles width transposed convolutions.
         xl = F.conv_transpose3d(
             xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
-        del xll

         xl += F.conv_transpose3d(
             xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
-        del xlh

         xh = F.conv_transpose3d(
             xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
-        del xhl

         xh += F.conv_transpose3d(
             xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
         )
-        del xhh

         # Handles time axis transposed convolutions.
         x = F.conv_transpose3d(
             xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )
-        del xl

         x += F.conv_transpose3d(
             xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
         )

@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-        self.build_pos_embed(device=device, dtype=dtype)
+        self.build_pos_embed(device=device)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )

-    def build_pos_embed(self, device=None, dtype=None):
+    def build_pos_embed(self, device=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,7 +242,6 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
-            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -477,8 +476,6 @@ class GeneralDIT(nn.Module):
                 inputs["original_shape"],
             )
             extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
-            del inputs
-
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -489,8 +486,6 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"

-            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
-                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -498,6 +493,7 @@ class GeneralDIT(nn.Module):
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
+                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )

         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

@@ -173,7 +173,6 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
-        dtype=None,
         **kwargs,
     ):
         """
@@ -185,9 +184,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))

     def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:

@@ -89,8 +89,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
         self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()

         num_parameters = sum(param.numel() for param in self.parameters())
-        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
-        logging.debug(
+        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
+        logging.info(
             f"z_channels={z_channels}, latent_channels={self.latent_channels}."
         )

@@ -230,7 +230,8 @@ class SingleStreamBlock(nn.Module):

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

@@ -5,15 +5,8 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management


 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q_shape = q.shape
-    k_shape = k.shape
-
-    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
-    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
-    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+    q, k = apply_rope(q, k, pe)

     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
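For context on the hunk above: one side calls apply_rope while the other inlines the same rotation on q and k. A minimal sketch of an apply_rope helper consistent with the inlined `-` lines (the actual helper lives elsewhere in the flux code and is not shown in this diff, so treat the exact signature as an assumption):

import torch
from torch import Tensor

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    # Reshape the last dimension into (pairs, 1, 2) and apply the rotary mix,
    # mirroring what the inlined q/k lines in the hunk do.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)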
@@ -293,17 +293,6 @@ def pytorch_attention(q, k, v):
     return out


-def vae_attention():
-    if model_management.xformers_enabled_vae():
-        logging.info("Using xformers attention in VAE")
-        return xformers_attention
-    elif model_management.pytorch_attention_enabled():
-        logging.info("Using pytorch attention in VAE")
-        return pytorch_attention
-    else:
-        logging.info("Using split attention in VAE")
-        return normal_attention
-
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
@@ -331,7 +320,15 @@ class AttnBlock(nn.Module):
                                 stride=1,
                                 padding=0)

-        self.optimized_attention = vae_attention()
+        if model_management.xformers_enabled_vae():
+            logging.info("Using xformers attention in VAE")
+            self.optimized_attention = xformers_attention
+        elif model_management.pytorch_attention_enabled():
+            logging.info("Using pytorch attention in VAE")
+            self.optimized_attention = pytorch_attention
+        else:
+            logging.info("Using split attention in VAE")
+            self.optimized_attention = normal_attention

     def forward(self, x):
         h_ = x
@@ -388,8 +388,8 @@ class VAE:
                 ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
                 self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
                 #TODO: these values are a bit off because this is not a standard VAE
-                self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-                self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
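The two lambdas above are rough VRAM estimators keyed off the tensor shape; only the scaling constants differ between the two sides. A small illustration of how such an estimate comes out in bytes (the shape is hypothetical and dtype_size here stands in for model_management.dtype_size, so this is a sketch rather than ComfyUI's own code):

import torch

def dtype_size(dtype):
    # bytes per element
    return torch.tensor([], dtype=dtype).element_size()

memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * dtype_size(dtype)

# hypothetical 5-D latent shape (batch, channels, t, h, w)
print(memory_used_decode((1, 16, 2, 60, 104), torch.bfloat16) / 1e9, "GB")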
@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo

-    memory_usage_factor = 1.8 #TODO
+    memory_usage_factor = 2.0 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Cosmos1CV8x8x8

-    memory_usage_factor = 1.6 #TODO
+    memory_usage_factor = 2.4 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

@@ -71,8 +71,8 @@ class CosmosImageToVideoLatent:
         mask[:, :, -latent_temp.shape[-3]:] *= 0.0

         out_latent = {}
-        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
-        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+        out_latent["samples"] = latent
+        out_latent["noise_mask"] = mask
         return (out_latent,)


@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.11"
+__version__ = "0.3.10"
11
main.py
@@ -138,6 +138,8 @@ import server
 from server import BinaryEventTypes
 import nodes
 import comfy.model_management
+from app.database.db import can_create_session, init_db
+from app.model_processor import model_processor

 def cuda_malloc_warning():
     device = comfy.model_management.get_torch_device()
@@ -262,6 +264,11 @@ def start_comfyui(asyncio_loop=None):

     cuda_malloc_warning()

+    try:
+        init_db()
+    except Exception as e:
+        logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
+
     prompt_server.add_routes()
     hijack_progress(prompt_server)

@@ -270,6 +277,10 @@ def start_comfyui(asyncio_loop=None):
     if args.quick_test_for_ci:
         exit(0)

+    # Scan for changed model files and update db
+    if can_create_session():
+        model_processor.run()
+
     os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
     call_on_start = None
     if args.auto_launch:
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.11"
+version = "0.3.10"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -2,7 +2,6 @@ torch
 torchsde
 torchvision
 torchaudio
-numpy>=1.25.0
 einops
 transformers>=4.28.1
 tokenizers>=0.13.3
@@ -14,6 +13,8 @@ Pillow
 scipy
 tqdm
 psutil
+alembic
+SQLAlchemy

 #non essential dependencies:
 kornia>=0.7.1
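alembic and SQLAlchemy become hard dependencies on this branch. As a hedged sketch of how the init_db hook added to main.py could drive migrations with these libraries (the real app/database/db.py is not part of the hunks shown here, so this is an assumption rather than the actual implementation):

from alembic import command
from alembic.config import Config

def init_db():
    # assumes an alembic.ini that points at the project's migration scripts
    cfg = Config("alembic.ini")
    command.upgrade(cfg, "head")  # apply any pending schema migrations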
@@ -7,11 +7,33 @@ from PIL import Image
 from aiohttp import web
 from unittest.mock import patch
 from app.model_manager import ModelFileManager
+from app.database.models import Base, Model, Tag
+from comfy.cli_args import args
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker

 pytestmark = (
     pytest.mark.asyncio
 ) # This applies the asyncio mark to all test functions in the module

+@pytest.fixture
+def session():
+    # Configure in-memory database
+    args.database_url = "sqlite:///:memory:"
+
+    # Create engine and session factory
+    engine = create_engine(args.database_url)
+    Session = sessionmaker(bind=engine)
+
+    # Create all tables
+    Base.metadata.create_all(engine)
+
+    # Patch Session factory
+    with patch('app.database.db.Session', Session):
+        yield Session()
+
+    Base.metadata.drop_all(engine)
+
 @pytest.fixture
 def model_manager():
     return ModelFileManager()
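The session fixture above spins up an in-memory SQLite database and patches the app's Session factory so every test starts from a clean schema. The Model and Tag mappings it creates through Base.metadata come from app.database.models, which this diff does not include; a hypothetical sketch of declarative models consistent with how the tests use them (table names, columns, and the association table are assumptions):

from sqlalchemy import Column, ForeignKey, Integer, String, Table
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

# hypothetical association table for the many-to-many model<->tag link
model_tags = Table(
    "model_tags", Base.metadata,
    Column("model_id", Integer, ForeignKey("models.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True),
)

class Tag(Base):
    __tablename__ = "tags"
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True)

class Model(Base):
    __tablename__ = "models"
    id = Column(Integer, primary_key=True)  # surrogate key; the API addresses models by (type, path)
    type = Column(String)
    path = Column(String)
    title = Column(String)
    tags = relationship("Tag", secondary=model_tags)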
@@ -60,3 +82,287 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):

     # Clean up
     img.close()
+
+async def test_get_models(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    model.tags.append(tag)
+    session.add(tag)
+    session.add(model)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.get('/v2/models')
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 1
+    assert data[0]['path'] == 'model1.safetensors'
+    assert len(data[0]['tags']) == 1
+    assert data[0]['tags'][0]['name'] == 'test_tag'
+
+async def test_add_model(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    with patch('app.model_manager.model_processor') as mock_processor:
+        with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+            client = await aiohttp_client(app)
+            resp = await client.post('/v2/models', json={
+                'type': 'checkpoints',
+                'path': 'model1.safetensors',
+                'title': 'Test Model',
+                'tags': [tag_id]
+            })
+
+            assert resp.status == 200
+            data = await resp.json()
+            assert data['path'] == 'model1.safetensors'
+            assert len(data['tags']) == 1
+            assert data['tags'][0]['name'] == 'test_tag'
+
+            # Ensure that models are re-processed after adding
+            mock_processor.run.assert_called_once()
+
+async def test_delete_model(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value=None):
+        client = await aiohttp_client(app)
+        resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
+        assert resp.status == 204
+
+        # Verify model was deleted
+        model = session.query(Model).first()
+        assert model is None
+
+async def test_delete_model_file_exists(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
+        assert resp.status == 400
+
+        data = await resp.json()
+        assert "file exists" in data["error"].lower()
+
+        # Verify model was not deleted
+        model = session.query(Model).first()
+        assert model is not None
+        assert model.path == 'model1.safetensors'
+
+async def test_get_tags(aiohttp_client, app, session):
+    tags = [Tag(name='tag1'), Tag(name='tag2')]
+    for tag in tags:
+        session.add(tag)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.get('/v2/tags')
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data) == 2
+    assert {t['name'] for t in data} == {'tag1', 'tag2'}
+
+async def test_create_tag(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/tags', json={'name': 'new_tag'})
+    assert resp.status == 200
+    data = await resp.json()
+    assert data['name'] == 'new_tag'
+
+    # Verify tag was created
+    tag = session.query(Tag).first()
+    assert tag.name == 'new_tag'
+
+async def test_delete_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.delete(f'/v2/tags?id={tag_id}')
+    assert resp.status == 204
+
+    # Verify tag was deleted
+    tag = session.query(Tag).first()
+    assert tag is None
+
+async def test_add_model_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(tag)
+    session.add(model)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models/tags', json={
+        'tag': tag_id,
+        'type': 'checkpoints',
+        'path': 'model1.safetensors'
+    })
+    assert resp.status == 200
+    data = await resp.json()
+    assert len(data['tags']) == 1
+    assert data['tags'][0]['name'] == 'test_tag'
+
+async def test_delete_model_tag(aiohttp_client, app, session):
+    tag = Tag(name='test_tag')
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    model.tags.append(tag)
+    session.add(tag)
+    session.add(model)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
+    assert resp.status == 204
+
+    # Verify tag was removed
+    model = session.query(Model).first()
+    assert len(model.tags) == 0
+
+async def test_add_model_duplicate(aiohttp_client, app, session):
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'model1.safetensors',
+            'title': 'Duplicate Model'
+        })
+        assert resp.status == 400
+
+async def test_add_model_missing_fields(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={})
+    assert resp.status == 400
+
+async def test_add_tag_missing_name(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/tags', json={})
+    assert resp.status == 400
+
+async def test_delete_model_not_found(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
+    assert resp.status == 404
+
+async def test_delete_tag_not_found(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/tags?id=999')
+    assert resp.status == 404
+
+async def test_add_model_missing_path(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={
+        'type': 'checkpoints',
+        'title': 'Test Model'
+    })
+    assert resp.status == 400
+    data = await resp.json()
+    assert "path" in data["error"].lower()
+
+async def test_add_model_invalid_field(aiohttp_client, app, session):
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models', json={
+        'type': 'checkpoints',
+        'path': 'model1.safetensors',
+        'invalid_field': 'some value'
+    })
+    assert resp.status == 400
+    data = await resp.json()
+    assert "invalid field" in data["error"].lower()
+
+async def test_add_model_nonexistent_file(aiohttp_client, app, session):
+    with patch('app.model_manager.get_full_path', return_value=None):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'nonexistent.safetensors'
+        })
+        assert resp.status == 404
+        data = await resp.json()
+        assert "file" in data["error"].lower()
+
+async def test_add_model_invalid_tag(aiohttp_client, app, session):
+    with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
+        client = await aiohttp_client(app)
+        resp = await client.post('/v2/models', json={
+            'type': 'checkpoints',
+            'path': 'model1.safetensors',
+            'tags': [999] # Non-existent tag ID
+        })
+        assert resp.status == 404
+        data = await resp.json()
+        assert "tag" in data["error"].lower()
+
+async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
+    # Create a tag but no model
+    tag = Tag(name='test_tag')
+    session.add(tag)
+    session.commit()
+    tag_id = tag.id
+
+    client = await aiohttp_client(app)
+    resp = await client.post('/v2/models/tags', json={
+        'tag': tag_id,
+        'type': 'checkpoints',
+        'path': 'nonexistent.safetensors'
+    })
+    assert resp.status == 404
+    data = await resp.json()
+    assert "model" in data["error"].lower()
+
+async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
+    # Create a model first
+    model = Model(
+        type='checkpoints',
+        path='model1.safetensors',
+        title='Test Model'
+    )
+    session.add(model)
+    session.commit()
+
+    client = await aiohttp_client(app)
+    resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
+    assert resp.status == 400
+    data = await resp.json()
+    assert "invalid tag id" in data["error"].lower()
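Taken together, these tests also document the v2 model-manager HTTP routes: GET/POST/DELETE on /v2/models, /v2/tags, and /v2/models/tags. A hedged example of exercising two of those routes against a running instance (the host, port, and exact response bodies are assumptions, not guaranteed by this diff):

import json
import urllib.request

BASE = "http://127.0.0.1:8188"  # assumed local address

# list registered models
with urllib.request.urlopen(BASE + "/v2/models") as resp:
    print(json.load(resp))

# create a tag
req = urllib.request.Request(
    BASE + "/v2/tags",
    data=json.dumps({"name": "favorites"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))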
12
utils/web.py
Normal file
@@ -0,0 +1,12 @@
+import json
+from datetime import datetime
+
+
+class DateTimeEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, datetime):
+            return obj.isoformat()
+        return super().default(obj)
+
+
+dumps = DateTimeEncoder().encode
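The new utils/web.py is a thin wrapper: dumps behaves like json.dumps except that datetime values are serialized as ISO-8601 strings. A small usage example (assuming utils.web is importable from the repository root):

from datetime import datetime
from utils.web import dumps

print(dumps({"created": datetime(2025, 1, 1, 12, 0)}))
# -> {"created": "2025-01-01T12:00:00"}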