Compare commits
3 Commits
model_mana
...
model_mana
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01110de8a3 | ||
|
|
785a220757 | ||
|
|
b6b475191d |
119
alembic.ini
119
alembic.ini
@@ -1,119 +0,0 @@
|
|||||||
# A generic, single database configuration.
|
|
||||||
|
|
||||||
[alembic]
|
|
||||||
# path to migration scripts
|
|
||||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
|
||||||
script_location = alembic_db
|
|
||||||
|
|
||||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
|
||||||
# Uncomment the line below if you want the files to be prepended with date and time
|
|
||||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
|
||||||
# for all available tokens
|
|
||||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
|
||||||
|
|
||||||
# sys.path path, will be prepended to sys.path if present.
|
|
||||||
# defaults to the current working directory.
|
|
||||||
prepend_sys_path = .
|
|
||||||
|
|
||||||
# timezone to use when rendering the date within the migration file
|
|
||||||
# as well as the filename.
|
|
||||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
|
||||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
|
||||||
# string value is passed to ZoneInfo()
|
|
||||||
# leave blank for localtime
|
|
||||||
# timezone =
|
|
||||||
|
|
||||||
# max length of characters to apply to the "slug" field
|
|
||||||
# truncate_slug_length = 40
|
|
||||||
|
|
||||||
# set to 'true' to run the environment during
|
|
||||||
# the 'revision' command, regardless of autogenerate
|
|
||||||
# revision_environment = false
|
|
||||||
|
|
||||||
# set to 'true' to allow .pyc and .pyo files without
|
|
||||||
# a source .py file to be detected as revisions in the
|
|
||||||
# versions/ directory
|
|
||||||
# sourceless = false
|
|
||||||
|
|
||||||
# version location specification; This defaults
|
|
||||||
# to alembic_db/versions. When using multiple version
|
|
||||||
# directories, initial revisions must be specified with --version-path.
|
|
||||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
|
||||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
|
||||||
|
|
||||||
# version path separator; As mentioned above, this is the character used to split
|
|
||||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
|
||||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
|
||||||
# Valid values for version_path_separator are:
|
|
||||||
#
|
|
||||||
# version_path_separator = :
|
|
||||||
# version_path_separator = ;
|
|
||||||
# version_path_separator = space
|
|
||||||
# version_path_separator = newline
|
|
||||||
#
|
|
||||||
# Use os.pathsep. Default configuration used for new projects.
|
|
||||||
version_path_separator = os
|
|
||||||
|
|
||||||
# set to 'true' to search source files recursively
|
|
||||||
# in each "version_locations" directory
|
|
||||||
# new in Alembic version 1.10
|
|
||||||
# recursive_version_locations = false
|
|
||||||
|
|
||||||
# the output encoding used when revision files
|
|
||||||
# are written from script.py.mako
|
|
||||||
# output_encoding = utf-8
|
|
||||||
|
|
||||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
|
||||||
# post_write_hooks defines scripts or Python functions that are run
|
|
||||||
# on newly generated revision scripts. See the documentation for further
|
|
||||||
# detail and examples
|
|
||||||
|
|
||||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
|
||||||
# hooks = black
|
|
||||||
# black.type = console_scripts
|
|
||||||
# black.entrypoint = black
|
|
||||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
|
||||||
|
|
||||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
|
||||||
# hooks = ruff
|
|
||||||
# ruff.type = exec
|
|
||||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
|
||||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
|
||||||
|
|
||||||
# Logging configuration
|
|
||||||
[loggers]
|
|
||||||
keys = root,sqlalchemy,alembic
|
|
||||||
|
|
||||||
[handlers]
|
|
||||||
keys = console
|
|
||||||
|
|
||||||
[formatters]
|
|
||||||
keys = generic
|
|
||||||
|
|
||||||
[logger_root]
|
|
||||||
level = WARNING
|
|
||||||
handlers = console
|
|
||||||
qualname =
|
|
||||||
|
|
||||||
[logger_sqlalchemy]
|
|
||||||
level = WARNING
|
|
||||||
handlers =
|
|
||||||
qualname = sqlalchemy.engine
|
|
||||||
|
|
||||||
[logger_alembic]
|
|
||||||
level = INFO
|
|
||||||
handlers =
|
|
||||||
qualname = alembic
|
|
||||||
|
|
||||||
[handler_console]
|
|
||||||
class = StreamHandler
|
|
||||||
args = (sys.stderr,)
|
|
||||||
level = NOTSET
|
|
||||||
formatter = generic
|
|
||||||
|
|
||||||
[formatter_generic]
|
|
||||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
|
||||||
datefmt = %H:%M:%S
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
## Generate new revision
|
|
||||||
1. Update models in `/app/database/models.py`
|
|
||||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
from logging.config import fileConfig
|
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
|
||||||
# access to the values within the .ini file in use.
|
|
||||||
config = context.config
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
|
||||||
# This line sets up loggers basically.
|
|
||||||
if config.config_file_name is not None:
|
|
||||||
fileConfig(config.config_file_name)
|
|
||||||
|
|
||||||
from app.database.models import Base
|
|
||||||
target_metadata = Base.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
|
||||||
# can be acquired:
|
|
||||||
# my_important_option = config.get_main_option("my_important_option")
|
|
||||||
# ... etc.
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
|
||||||
"""Run migrations in 'offline' mode.
|
|
||||||
|
|
||||||
This configures the context with just a URL
|
|
||||||
and not an Engine, though an Engine is acceptable
|
|
||||||
here as well. By skipping the Engine creation
|
|
||||||
we don't even need a DBAPI to be available.
|
|
||||||
|
|
||||||
Calls to context.execute() here emit the given string to the
|
|
||||||
script output.
|
|
||||||
|
|
||||||
"""
|
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
|
||||||
context.configure(
|
|
||||||
url=url,
|
|
||||||
target_metadata=target_metadata,
|
|
||||||
literal_binds=True,
|
|
||||||
dialect_opts={"paramstyle": "named"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_online() -> None:
|
|
||||||
"""Run migrations in 'online' mode.
|
|
||||||
|
|
||||||
In this scenario we need to create an Engine
|
|
||||||
and associate a connection with the context.
|
|
||||||
|
|
||||||
"""
|
|
||||||
connectable = engine_from_config(
|
|
||||||
config.get_section(config.config_ini_section, {}),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
|
||||||
context.configure(
|
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
if context.is_offline_mode():
|
|
||||||
run_migrations_offline()
|
|
||||||
else:
|
|
||||||
run_migrations_online()
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
"""${message}
|
|
||||||
|
|
||||||
Revision ID: ${up_revision}
|
|
||||||
Revises: ${down_revision | comma,n}
|
|
||||||
Create Date: ${create_date}
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
${imports if imports else ""}
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = ${repr(up_revision)}
|
|
||||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
|
||||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
${upgrades if upgrades else "pass"}
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
${downgrades if downgrades else "pass"}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
"""init
|
|
||||||
|
|
||||||
Revision ID: 2fb22c4fff36
|
|
||||||
Revises:
|
|
||||||
Create Date: 2025-03-27 19:00:47.686079
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '2fb22c4fff36'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('model',
|
|
||||||
sa.Column('type', sa.Text(), nullable=False),
|
|
||||||
sa.Column('path', sa.Text(), nullable=False),
|
|
||||||
sa.Column('title', sa.Text(), nullable=True),
|
|
||||||
sa.Column('description', sa.Text(), nullable=True),
|
|
||||||
sa.Column('architecture', sa.Text(), nullable=True),
|
|
||||||
sa.Column('hash', sa.Text(), nullable=True),
|
|
||||||
sa.Column('source_url', sa.Text(), nullable=True),
|
|
||||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('type', 'path')
|
|
||||||
)
|
|
||||||
op.create_table('tag',
|
|
||||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
|
||||||
sa.Column('name', sa.Text(), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('name')
|
|
||||||
)
|
|
||||||
op.create_table('model_tag',
|
|
||||||
sa.Column('model_type', sa.Text(), nullable=False),
|
|
||||||
sa.Column('model_path', sa.Text(), nullable=False),
|
|
||||||
sa.Column('tag_id', sa.Integer(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['model_type', 'model_path'], ['model.type', 'model.path'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('model_type', 'model_path', 'tag_id')
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('model_tag')
|
|
||||||
op.drop_table('tag')
|
|
||||||
op.drop_table('model')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
@@ -1,118 +1,126 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import sqlite3
|
||||||
import sys
|
from contextlib import contextmanager
|
||||||
from app.database.models import Tag
|
from queue import Queue, Empty, Full
|
||||||
|
import threading
|
||||||
|
from app.database.updater import DatabaseUpdater
|
||||||
|
import folder_paths
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
try:
|
|
||||||
import alembic
|
|
||||||
import sqlalchemy
|
|
||||||
except ImportError as e:
|
|
||||||
req_path = os.path.abspath(
|
|
||||||
os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
|
|
||||||
)
|
|
||||||
logging.error(
|
|
||||||
f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
|
|
||||||
)
|
|
||||||
exit(-1)
|
|
||||||
|
|
||||||
from alembic import command
|
class Database:
|
||||||
from alembic.config import Config
|
def __init__(self, database_path=None, pool_size=1):
|
||||||
from alembic.runtime.migration import MigrationContext
|
if database_path is None:
|
||||||
from alembic.script import ScriptDirectory
|
self.exists = False
|
||||||
from sqlalchemy import create_engine
|
database_path = "file::memory:?cache=shared"
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
Session = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_alembic_config():
|
|
||||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
|
||||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
|
||||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
|
||||||
|
|
||||||
config = Config(config_path)
|
|
||||||
config.set_main_option("script_location", scripts_path)
|
|
||||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_path():
|
|
||||||
url = args.database_url
|
|
||||||
if url.startswith("sqlite:///"):
|
|
||||||
return url.split("///")[1]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
|
||||||
db_url = args.database_url
|
|
||||||
logging.debug(f"Database URL: {db_url}")
|
|
||||||
|
|
||||||
config = get_alembic_config()
|
|
||||||
|
|
||||||
# Check if we need to upgrade
|
|
||||||
engine = create_engine(db_url)
|
|
||||||
conn = engine.connect()
|
|
||||||
|
|
||||||
context = MigrationContext.configure(conn)
|
|
||||||
current_rev = context.get_current_revision()
|
|
||||||
|
|
||||||
script = ScriptDirectory.from_config(config)
|
|
||||||
target_rev = script.get_current_head()
|
|
||||||
|
|
||||||
if current_rev != target_rev:
|
|
||||||
# Backup the database pre upgrade
|
|
||||||
db_path = get_db_path()
|
|
||||||
backup_path = db_path + ".bkp"
|
|
||||||
if os.path.exists(db_path):
|
|
||||||
shutil.copy(db_path, backup_path)
|
|
||||||
else:
|
else:
|
||||||
backup_path = None
|
self.exists = os.path.exists(database_path)
|
||||||
|
|
||||||
|
self.database_path = database_path
|
||||||
|
self.pool_size = pool_size
|
||||||
|
# Store connections in a pool, default to 1 as normal usage is going to be from a single thread at a time
|
||||||
|
self.connection_pool: Queue = Queue(maxsize=pool_size)
|
||||||
|
self._db_lock = threading.Lock()
|
||||||
|
self._initialized = False
|
||||||
|
self._closing = False
|
||||||
|
self._after_update_callbacks = []
|
||||||
|
|
||||||
|
def _setup(self):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._db_lock:
|
||||||
|
if not self._initialized:
|
||||||
|
self._make_db()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def _create_connection(self):
|
||||||
|
# TODO: Catch error for sqlite lib missing on linux
|
||||||
|
logging.info(f"Creating connection to {self.database_path}")
|
||||||
|
conn = sqlite3.connect(
|
||||||
|
self.database_path,
|
||||||
|
check_same_thread=False,
|
||||||
|
uri=self.database_path.startswith("file::"),
|
||||||
|
)
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
self.exists = True
|
||||||
|
logging.info(f"Connected!")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _make_db(self):
|
||||||
|
with self._get_connection() as con:
|
||||||
|
updater = DatabaseUpdater(con, self.database_path)
|
||||||
|
result = updater.update()
|
||||||
|
if result is not None:
|
||||||
|
old_version, new_version = result
|
||||||
|
|
||||||
|
for callback in self._after_update_callbacks:
|
||||||
|
callback(old_version, new_version)
|
||||||
|
|
||||||
|
def _transform(self, row, columns):
|
||||||
|
return {col.name: value for value, col in zip(row, columns)}
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _get_connection(self):
|
||||||
|
if self._closing:
|
||||||
|
raise Exception("Database is shutting down")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
command.upgrade(config, target_rev)
|
# Try to get connection from pool
|
||||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
connection = self.connection_pool.get_nowait()
|
||||||
except Exception as e:
|
except Empty:
|
||||||
if backup_path:
|
# Create new connection if pool is empty
|
||||||
# Restore the database from backup if upgrade fails
|
connection = self._create_connection()
|
||||||
shutil.copy(backup_path, db_path)
|
|
||||||
os.remove(backup_path)
|
|
||||||
logging.error(f"Error upgrading database: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
global Session
|
try:
|
||||||
Session = sessionmaker(bind=engine)
|
yield connection
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
# Try to add to pool if it's empty
|
||||||
|
self.connection_pool.put_nowait(connection)
|
||||||
|
except Full:
|
||||||
|
# Pool is full, close the connection
|
||||||
|
connection.close()
|
||||||
|
|
||||||
if not current_rev:
|
@contextmanager
|
||||||
# Init db, populate models
|
def get_connection(self):
|
||||||
from app.model_processor import model_processor
|
# Setup the database if it's not already initialized
|
||||||
|
self._setup()
|
||||||
|
with self._get_connection() as connection:
|
||||||
|
yield connection
|
||||||
|
|
||||||
session = create_session()
|
def execute(self, sql, *args):
|
||||||
model_processor.populate_models(session)
|
with self.get_connection() as connection:
|
||||||
|
cursor = connection.execute(sql, args)
|
||||||
|
results = cursor.fetchall()
|
||||||
|
return results
|
||||||
|
|
||||||
# populate tags
|
def register_after_update_callback(self, callback):
|
||||||
tags = (
|
self._after_update_callbacks.append(callback)
|
||||||
"character",
|
|
||||||
"style",
|
|
||||||
"concept",
|
|
||||||
"clothing",
|
|
||||||
"pose",
|
|
||||||
"background",
|
|
||||||
"vehicle",
|
|
||||||
"object",
|
|
||||||
"animal",
|
|
||||||
"action",
|
|
||||||
)
|
|
||||||
for tag in tags:
|
|
||||||
session.add(Tag(name=tag))
|
|
||||||
|
|
||||||
session.commit()
|
def close(self):
|
||||||
|
if self._closing:
|
||||||
|
return
|
||||||
|
# Drain and close all connections in the pool
|
||||||
|
self._closing = True
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
conn = self.connection_pool.get_nowait()
|
||||||
|
conn.close()
|
||||||
|
except Empty:
|
||||||
|
break
|
||||||
|
self._closing = False
|
||||||
|
|
||||||
def can_create_session():
|
def __del__(self):
|
||||||
return Session is not None
|
try:
|
||||||
|
self.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
def create_session():
|
|
||||||
return Session()
|
# Create a global instance
|
||||||
|
db_path = None
|
||||||
|
if not args.memory_database:
|
||||||
|
db_path = folder_paths.get_user_directory() + "/comfyui.db"
|
||||||
|
db = Database(db_path)
|
||||||
|
|||||||
343
app/database/entities.py
Normal file
343
app/database/entities.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
from typing import Optional, Any, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import wraps
|
||||||
|
from aiohttp import web
|
||||||
|
from app.database.db import db
|
||||||
|
|
||||||
|
primitives = (bool, str, int, float, type(None))
|
||||||
|
|
||||||
|
|
||||||
|
def is_primitive(obj):
|
||||||
|
return isinstance(obj, primitives)
|
||||||
|
|
||||||
|
|
||||||
|
class EntityError(Exception):
|
||||||
|
def __init__(
|
||||||
|
self, message: str, field: str = None, value: Any = None, status_code: int = 400
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
self.field = field
|
||||||
|
self.value = value
|
||||||
|
self.status_code = status_code
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
result = {"message": self.message}
|
||||||
|
if self.field is not None:
|
||||||
|
result["field"] = self.field
|
||||||
|
if self.value is not None:
|
||||||
|
result["value"] = self.value
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.message} {self.field} {self.value}"
|
||||||
|
|
||||||
|
|
||||||
|
class EntityCommon(dict):
|
||||||
|
@classmethod
|
||||||
|
def _get_route(cls, include_key: bool):
|
||||||
|
route = f"/db/{cls._table_name}"
|
||||||
|
if include_key:
|
||||||
|
route += "".join([f"/{{{k}}}" for k in cls._key_columns])
|
||||||
|
return route
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable):
|
||||||
|
route = cls._get_route(include_key)
|
||||||
|
|
||||||
|
@getattr(routes, verb)(route)
|
||||||
|
async def _(request):
|
||||||
|
try:
|
||||||
|
data = await handler(request)
|
||||||
|
if data is None:
|
||||||
|
return web.json_response(status=204)
|
||||||
|
|
||||||
|
return web.json_response(data)
|
||||||
|
except EntityError as e:
|
||||||
|
return web.json_response(e.to_json(), status=e.status_code)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _transform(cls, row: list[Any]):
|
||||||
|
return {col: value for col, value in zip(cls._columns, row)}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _transform_rows(cls, rows: list[list[Any]]):
|
||||||
|
return [cls._transform(row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_key(cls, request):
|
||||||
|
return {key: request.match_info.get(key, None) for key in cls._key_columns}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False):
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise EntityError("Invalid data")
|
||||||
|
|
||||||
|
# Ensure all required fields are present
|
||||||
|
for field in data:
|
||||||
|
if field not in fields:
|
||||||
|
raise EntityError("Unknown field", field)
|
||||||
|
|
||||||
|
for key in fields:
|
||||||
|
col = cls._columns[key]
|
||||||
|
if key not in data:
|
||||||
|
if col.required and not allow_missing:
|
||||||
|
raise EntityError("Missing field", key)
|
||||||
|
else:
|
||||||
|
# e.g. for updates, we allow missing fields
|
||||||
|
continue
|
||||||
|
elif data[key] is None and col.required:
|
||||||
|
# Dont allow None for required fields
|
||||||
|
raise EntityError("Required field", key)
|
||||||
|
|
||||||
|
# Validate data type
|
||||||
|
value = data[key]
|
||||||
|
|
||||||
|
if value is not None and not is_primitive(value):
|
||||||
|
raise EntityError("Invalid value", key, value)
|
||||||
|
|
||||||
|
try:
|
||||||
|
type = col.type
|
||||||
|
if value is not None and not isinstance(value, type):
|
||||||
|
value = type(value)
|
||||||
|
result[key] = value
|
||||||
|
except Exception:
|
||||||
|
raise EntityError("Invalid value", key, value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_id(cls, id: dict):
|
||||||
|
return cls._validate(cls._key_columns, id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_data(cls, data: dict, allow_missing: bool = False):
|
||||||
|
return cls._validate(cls._columns.keys(), data, allow_missing)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
if name in self._columns:
|
||||||
|
self[name] = value
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name in self:
|
||||||
|
return self[name]
|
||||||
|
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||||
|
|
||||||
|
|
||||||
|
class GetEntity(EntityCommon):
|
||||||
|
@classmethod
|
||||||
|
def get(cls, top: Optional[int] = None, where: Optional[str] = None):
|
||||||
|
limit = ""
|
||||||
|
if top is not None and isinstance(top, int):
|
||||||
|
limit = f" LIMIT {top}"
|
||||||
|
result = db.execute(
|
||||||
|
f"SELECT * FROM {cls._table_name}{limit}{f' WHERE {where}' if where else ''}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map each row in result to an instance of the class
|
||||||
|
return cls._transform_rows(result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def get_handler(request):
|
||||||
|
top = request.rel_url.query.get("top", None)
|
||||||
|
if top is not None:
|
||||||
|
try:
|
||||||
|
top = int(top)
|
||||||
|
except Exception:
|
||||||
|
raise EntityError("Invalid top parameter", "top", top)
|
||||||
|
return cls.get(top)
|
||||||
|
|
||||||
|
cls._register_route(routes, "get", False, get_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class GetEntityById(EntityCommon):
|
||||||
|
@classmethod
|
||||||
|
def get_by_id(cls, id: dict):
|
||||||
|
id = cls._validate_id(id)
|
||||||
|
|
||||||
|
result = db.execute(
|
||||||
|
f"SELECT * FROM {cls._table_name} WHERE {cls._where_clause}",
|
||||||
|
*[id[key] for key in cls._key_columns],
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls._transform_rows(result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def get_by_id_handler(request):
|
||||||
|
id = cls._extract_key(request)
|
||||||
|
return cls.get_by_id(id)
|
||||||
|
|
||||||
|
cls._register_route(routes, "get", True, get_by_id_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEntity(EntityCommon):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, data: dict, allow_upsert: bool = False):
|
||||||
|
data = cls._validate_data(data)
|
||||||
|
values = ", ".join(["?"] * len(data))
|
||||||
|
on_conflict = ""
|
||||||
|
|
||||||
|
data_keys = ", ".join(list(data.keys()))
|
||||||
|
if allow_upsert:
|
||||||
|
# Remove key columns from data
|
||||||
|
upsert_keys = [key for key in data if key not in cls._key_columns]
|
||||||
|
|
||||||
|
set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys])
|
||||||
|
on_conflict = f" ON CONFLICT ({', '.join(cls._key_columns)}) DO UPDATE SET {set_clause}"
|
||||||
|
sql = f"INSERT INTO {cls._table_name} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *"
|
||||||
|
result = db.execute(
|
||||||
|
sql,
|
||||||
|
*[data[key] for key in data],
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(result) == 0:
|
||||||
|
raise EntityError("Failed to create entity", status_code=500)
|
||||||
|
|
||||||
|
return cls._transform_rows(result)[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def create_handler(request):
|
||||||
|
data = await request.json()
|
||||||
|
return cls.create(data)
|
||||||
|
|
||||||
|
cls._register_route(routes, "post", False, create_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateEntity(EntityCommon):
|
||||||
|
@classmethod
|
||||||
|
def update(cls, id: list, data: dict):
|
||||||
|
id = cls._validate_id(id)
|
||||||
|
data = cls._validate_data(data, allow_missing=True)
|
||||||
|
|
||||||
|
sql = f"UPDATE {cls._table_name} SET {', '.join([f'{k} = ?' for k in data])} WHERE {cls._where_clause} RETURNING *"
|
||||||
|
result = db.execute(
|
||||||
|
sql,
|
||||||
|
*[data[key] for key in data],
|
||||||
|
*[id[key] for key in cls._key_columns],
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(result) == 0:
|
||||||
|
raise EntityError("Failed to update entity", status_code=404)
|
||||||
|
|
||||||
|
return cls._transform_rows(result)[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def update_handler(request):
|
||||||
|
id = cls._extract_key(request)
|
||||||
|
data = await request.json()
|
||||||
|
return cls.update(id, data)
|
||||||
|
|
||||||
|
cls._register_route(routes, "patch", True, update_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class UpsertEntity(CreateEntity):
|
||||||
|
@classmethod
|
||||||
|
def upsert(cls, data: dict):
|
||||||
|
return cls.create(data, allow_upsert=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def upsert_handler(request):
|
||||||
|
data = await request.json()
|
||||||
|
return cls.upsert(data)
|
||||||
|
|
||||||
|
cls._register_route(routes, "put", False, upsert_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteEntity(EntityCommon):
|
||||||
|
@classmethod
|
||||||
|
def delete(cls, id: list):
|
||||||
|
id = cls._validate_id(id)
|
||||||
|
db.execute(
|
||||||
|
f"DELETE FROM {cls._table_name} WHERE {cls._where_clause}",
|
||||||
|
*[id[key] for key in cls._key_columns],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_route(cls, routes):
|
||||||
|
async def delete_handler(request):
|
||||||
|
id = cls._extract_key(request)
|
||||||
|
cls.delete(id)
|
||||||
|
|
||||||
|
cls._register_route(routes, "delete", True, delete_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Column:
|
||||||
|
type: Any
|
||||||
|
required: bool = False
|
||||||
|
key: bool = False
|
||||||
|
default: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
def column(type_: Any, required: bool = False, key: bool = False, default: Any = None):
|
||||||
|
return Column(type_, required, key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def table(table_name: str):
|
||||||
|
def decorator(cls):
|
||||||
|
# Store table name
|
||||||
|
cls._table_name = table_name
|
||||||
|
|
||||||
|
# Process column definitions
|
||||||
|
columns: dict[str, Column] = {}
|
||||||
|
for attr_name, attr_value in cls.__dict__.items():
|
||||||
|
if isinstance(attr_value, Column):
|
||||||
|
columns[attr_name] = attr_value
|
||||||
|
|
||||||
|
# Store columns metadata
|
||||||
|
cls._columns = columns
|
||||||
|
cls._key_columns = [col for col in columns if columns[col].key]
|
||||||
|
cls._column_csv = ", ".join([col for col in columns])
|
||||||
|
cls._where_clause = " AND ".join([f"{col} = ?" for col in cls._key_columns])
|
||||||
|
|
||||||
|
# Add initialization
|
||||||
|
original_init = cls.__init__
|
||||||
|
|
||||||
|
@wraps(original_init)
|
||||||
|
def new_init(self, *args, **kwargs):
|
||||||
|
# Initialize columns with default values
|
||||||
|
for col_name, col_def in cls._columns.items():
|
||||||
|
setattr(self, col_name, col_def.default)
|
||||||
|
# Call original init
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
cls.__init__ = new_init
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
@table("models")
|
||||||
|
class Model(BaseEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
path: str = column(str, required=True)
|
||||||
|
name: str = column(str, required=True)
|
||||||
|
description: Optional[str] = column(str)
|
||||||
|
architecture: Optional[str] = column(str)
|
||||||
|
type: str = column(str, required=True)
|
||||||
|
hash: Optional[str] = column(str)
|
||||||
|
source_url: Optional[str] = column(str)
|
||||||
|
|
||||||
|
return Model
|
||||||
|
|
||||||
|
|
||||||
|
@table("test")
|
||||||
|
class Test(GetEntity, CreateEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
|
||||||
|
|
||||||
|
Model = test()
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
from sqlalchemy import (
|
|
||||||
Column,
|
|
||||||
Integer,
|
|
||||||
Text,
|
|
||||||
DateTime,
|
|
||||||
Table,
|
|
||||||
ForeignKeyConstraint,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import relationship, declarative_base
|
|
||||||
from sqlalchemy.sql import func
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(obj):
|
|
||||||
fields = obj.__table__.columns.keys()
|
|
||||||
return {
|
|
||||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
|
||||||
for field in fields
|
|
||||||
if (val := getattr(obj, field))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ModelTag = Table(
|
|
||||||
"model_tag",
|
|
||||||
Base.metadata,
|
|
||||||
Column(
|
|
||||||
"model_type",
|
|
||||||
Text,
|
|
||||||
primary_key=True,
|
|
||||||
),
|
|
||||||
Column(
|
|
||||||
"model_path",
|
|
||||||
Text,
|
|
||||||
primary_key=True,
|
|
||||||
),
|
|
||||||
Column("tag_id", Integer, primary_key=True),
|
|
||||||
ForeignKeyConstraint(
|
|
||||||
["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
|
|
||||||
),
|
|
||||||
ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(Base):
|
|
||||||
__tablename__ = "model"
|
|
||||||
|
|
||||||
type = Column(Text, primary_key=True)
|
|
||||||
path = Column(Text, primary_key=True)
|
|
||||||
title = Column(Text)
|
|
||||||
description = Column(Text)
|
|
||||||
architecture = Column(Text)
|
|
||||||
hash = Column(Text)
|
|
||||||
source_url = Column(Text)
|
|
||||||
date_added = Column(DateTime, server_default=func.now())
|
|
||||||
|
|
||||||
# Relationship with tags
|
|
||||||
tags = relationship("Tag", secondary=ModelTag, back_populates="models")
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
dict = to_dict(self)
|
|
||||||
dict["tags"] = [tag.to_dict() for tag in self.tags]
|
|
||||||
return dict
|
|
||||||
|
|
||||||
|
|
||||||
class Tag(Base):
|
|
||||||
__tablename__ = "tag"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
name = Column(Text, nullable=False, unique=True)
|
|
||||||
|
|
||||||
# Relationship with models
|
|
||||||
models = relationship("Model", secondary=ModelTag, back_populates="tags")
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return to_dict(self)
|
|
||||||
32
app/database/routes.py
Normal file
32
app/database/routes.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from app.database.db import db
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
def create_routes(
|
||||||
|
routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False
|
||||||
|
):
|
||||||
|
if get:
|
||||||
|
@routes.get(f"/{prefix}/{table}")
|
||||||
|
async def get_table(request):
|
||||||
|
connection = db.get_connection()
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM {table}")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
return web.json_response(rows)
|
||||||
|
|
||||||
|
if get_by_id:
|
||||||
|
@routes.get(f"/{prefix}/{table}/{id}")
|
||||||
|
async def get_table_by_id(request):
|
||||||
|
connection = db.get_connection()
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM {table} WHERE id = {id}")
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return web.json_response(row)
|
||||||
|
|
||||||
|
if post:
|
||||||
|
@routes.post(f"/{prefix}/{table}")
|
||||||
|
async def post_table(request):
|
||||||
|
data = await request.json()
|
||||||
|
connection = db.get_connection()
|
||||||
|
cursor = connection.cursor()
|
||||||
|
cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})")
|
||||||
|
return web.json_response({"status": "success"})
|
||||||
79
app/database/updater.py
Normal file
79
app/database/updater.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from app.database.versions.v1 import v1
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseUpdater:
|
||||||
|
def __init__(self, connection, database_path):
|
||||||
|
self.connection = connection
|
||||||
|
self.database_path = database_path
|
||||||
|
self.current_version = self.get_db_version()
|
||||||
|
self.version_updates = {
|
||||||
|
1: v1,
|
||||||
|
}
|
||||||
|
self.max_version = max(self.version_updates.keys())
|
||||||
|
self.update_required = self.current_version < self.max_version
|
||||||
|
logging.info(f"Database version: {self.current_version}")
|
||||||
|
|
||||||
|
def get_db_version(self):
|
||||||
|
return self.connection.execute("PRAGMA user_version").fetchone()[0]
|
||||||
|
|
||||||
|
def backup(self):
|
||||||
|
bkp_path = self.database_path + ".bkp"
|
||||||
|
if os.path.exists(bkp_path):
|
||||||
|
# TODO: auto-rollback failed upgrades
|
||||||
|
raise Exception(
|
||||||
|
f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
bkp = sqlite3.connect(bkp_path)
|
||||||
|
self.connection.backup(bkp)
|
||||||
|
bkp.close()
|
||||||
|
logging.info("Database backup taken pre-upgrade.")
|
||||||
|
return bkp_path
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if not self.update_required:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bkp_version = self.current_version
|
||||||
|
bkp_path = None
|
||||||
|
if self.current_version > 0:
|
||||||
|
bkp_path = self.backup()
|
||||||
|
|
||||||
|
logging.info(f"Updating database: {self.current_version} -> {self.max_version}")
|
||||||
|
|
||||||
|
dirname = os.path.dirname(__file__)
|
||||||
|
cursor = self.connection.cursor()
|
||||||
|
for version in range(self.current_version + 1, self.max_version + 1):
|
||||||
|
filename = os.path.join(dirname, f"versions/v{version}.sql")
|
||||||
|
if not os.path.exists(filename):
|
||||||
|
raise Exception(
|
||||||
|
f"Database update script for version {version} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as file:
|
||||||
|
sql = file.read()
|
||||||
|
cursor.executescript(sql)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to execute update script for version {version}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
method = self.version_updates[version]
|
||||||
|
if method is not None:
|
||||||
|
method(cursor)
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA user_version = %d" % self.max_version)
|
||||||
|
self.connection.commit()
|
||||||
|
cursor.close()
|
||||||
|
self.current_version = self.get_db_version()
|
||||||
|
|
||||||
|
if bkp_path:
|
||||||
|
# Keep a copy of the backup in case something goes wrong and we need to rollback
|
||||||
|
os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp")
|
||||||
|
logging.info(f"Upgrade to successful.")
|
||||||
|
|
||||||
|
return (bkp_version, self.current_version)
|
||||||
17
app/database/versions/v1.py
Normal file
17
app/database/versions/v1.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from folder_paths import folder_names_and_paths, get_filename_list, get_full_path
|
||||||
|
|
||||||
|
|
||||||
|
def v1(cursor):
|
||||||
|
print("Updating to v1")
|
||||||
|
for folder_name in folder_names_and_paths.keys():
|
||||||
|
if folder_name == "custom_nodes":
|
||||||
|
continue
|
||||||
|
|
||||||
|
files = get_filename_list(folder_name)
|
||||||
|
for file in files:
|
||||||
|
file_path = get_full_path(folder_name, file)
|
||||||
|
file_without_extension = file.rsplit(".", maxsplit=1)[0]
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO models (path, name, type) VALUES (?, ?, ?)",
|
||||||
|
(file_path, file_without_extension, folder_name),
|
||||||
|
)
|
||||||
41
app/database/versions/v1.sql
Normal file
41
app/database/versions/v1.sql
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS
|
||||||
|
models (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
path TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
architecture TEXT,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
hash TEXT,
|
||||||
|
source_url TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS
|
||||||
|
tags (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL UNIQUE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS
|
||||||
|
model_tags (
|
||||||
|
model_id INTEGER NOT NULL,
|
||||||
|
tag_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (model_id, tag_id),
|
||||||
|
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO
|
||||||
|
tags (name)
|
||||||
|
VALUES
|
||||||
|
('character'),
|
||||||
|
('style'),
|
||||||
|
('concept'),
|
||||||
|
('clothing'),
|
||||||
|
('poses'),
|
||||||
|
('background'),
|
||||||
|
('vehicle'),
|
||||||
|
('buildings'),
|
||||||
|
('objects'),
|
||||||
|
('animal'),
|
||||||
|
('action');
|
||||||
63
app/model_hasher.py
Normal file
63
app/model_hasher.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
|
||||||
|
class ModelHasher:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._thread = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._model_entity = None
|
||||||
|
|
||||||
|
def start(self, model_entity):
|
||||||
|
if args.disable_model_hashing:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._model_entity = model_entity
|
||||||
|
|
||||||
|
if self._thread is None:
|
||||||
|
# Lock to prevent multiple threads from starting
|
||||||
|
with self._lock:
|
||||||
|
if self._thread is None:
|
||||||
|
self._thread = threading.Thread(target=self._hash_models)
|
||||||
|
self._thread.daemon = True
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _get_models(self):
|
||||||
|
models = self._model_entity.get("WHERE hash IS NULL")
|
||||||
|
return models
|
||||||
|
|
||||||
|
def _hash_model(self, model_path):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
b = bytearray(128 * 1024)
|
||||||
|
mv = memoryview(b)
|
||||||
|
with open(model_path, "rb", buffering=0) as f:
|
||||||
|
while n := f.readinto(mv):
|
||||||
|
h.update(mv[:n])
|
||||||
|
hash = h.hexdigest()
|
||||||
|
return hash
|
||||||
|
|
||||||
|
def _hash_models(self):
|
||||||
|
while True:
|
||||||
|
models = self._get_models()
|
||||||
|
|
||||||
|
if len(models) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
time.sleep(0)
|
||||||
|
now = time.time()
|
||||||
|
logging.info(f"Hashing model {model['path']}")
|
||||||
|
hash = self._hash_model(model["path"])
|
||||||
|
logging.info(
|
||||||
|
f"Hashed model {model['path']} in {time.time() - now} seconds"
|
||||||
|
)
|
||||||
|
self._model_entity.update((model["id"],), {"hash": hash})
|
||||||
|
|
||||||
|
self._thread = None
|
||||||
|
|
||||||
|
|
||||||
|
model_hasher = ModelHasher()
|
||||||
@@ -1,30 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
from app.database.db import create_session
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
import glob
|
||||||
|
import comfy.utils
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from folder_paths import map_legacy, filter_files_extensions, get_full_path
|
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||||
from app.database.models import Tag, Model
|
|
||||||
from app.model_processor import get_model_previews, model_processor
|
|
||||||
from utils.web import dumps
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
import sqlalchemy.exc
|
|
||||||
|
|
||||||
|
|
||||||
def bad_request(message: str):
|
|
||||||
return web.json_response({"error": message}, status=400)
|
|
||||||
|
|
||||||
def missing_field(field: str):
|
|
||||||
return bad_request(f"{field} is required")
|
|
||||||
|
|
||||||
def not_found(message: str):
|
|
||||||
return web.json_response({"error": message + " not found"}, status=404)
|
|
||||||
|
|
||||||
class ModelFileManager:
|
class ModelFileManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||||
@@ -73,7 +62,7 @@ class ModelFileManager:
|
|||||||
folder = folders[0][path_index]
|
folder = folders[0][path_index]
|
||||||
full_filename = os.path.join(folder, filename)
|
full_filename = os.path.join(folder, filename)
|
||||||
|
|
||||||
previews = get_model_previews(full_filename)
|
previews = self.get_model_previews(full_filename)
|
||||||
default_preview = previews[0] if len(previews) > 0 else None
|
default_preview = previews[0] if len(previews) > 0 else None
|
||||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
@@ -87,183 +76,6 @@ class ModelFileManager:
|
|||||||
except:
|
except:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
@routes.get("/v2/models")
|
|
||||||
async def get_models(request):
|
|
||||||
with create_session() as session:
|
|
||||||
model_path = request.query.get("path", None)
|
|
||||||
model_type = request.query.get("type", None)
|
|
||||||
query = session.query(Model).options(joinedload(Model.tags))
|
|
||||||
if model_path:
|
|
||||||
query = query.filter(Model.path == model_path)
|
|
||||||
if model_type:
|
|
||||||
query = query.filter(Model.type == model_type)
|
|
||||||
models = query.all()
|
|
||||||
if model_path and model_type:
|
|
||||||
if len(models) == 0:
|
|
||||||
return not_found("Model")
|
|
||||||
return web.json_response(models[0].to_dict(), dumps=dumps)
|
|
||||||
|
|
||||||
return web.json_response([model.to_dict() for model in models], dumps=dumps)
|
|
||||||
|
|
||||||
@routes.post("/v2/models")
|
|
||||||
async def add_model(request):
|
|
||||||
with create_session() as session:
|
|
||||||
data = await request.json()
|
|
||||||
model_type = data.get("type", None)
|
|
||||||
model_path = data.get("path", None)
|
|
||||||
|
|
||||||
if not model_type:
|
|
||||||
return missing_field("type")
|
|
||||||
if not model_path:
|
|
||||||
return missing_field("path")
|
|
||||||
|
|
||||||
tags = data.pop("tags", [])
|
|
||||||
fields = Model.metadata.tables["model"].columns.keys()
|
|
||||||
|
|
||||||
# Validate keys are valid model fields
|
|
||||||
for key in data.keys():
|
|
||||||
if key not in fields:
|
|
||||||
return bad_request(f"Invalid field: {key}")
|
|
||||||
|
|
||||||
# Validate file exists
|
|
||||||
if not get_full_path(model_type, model_path):
|
|
||||||
return not_found(f"File '{model_type}/{model_path}'")
|
|
||||||
|
|
||||||
model = Model()
|
|
||||||
for field in fields:
|
|
||||||
if field in data:
|
|
||||||
setattr(model, field, data[field])
|
|
||||||
|
|
||||||
model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
|
|
||||||
for tag in tags:
|
|
||||||
if tag not in [t.id for t in model.tags]:
|
|
||||||
return not_found(f"Tag '{tag}'")
|
|
||||||
|
|
||||||
try:
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
except sqlalchemy.exc.IntegrityError as e:
|
|
||||||
session.rollback()
|
|
||||||
return bad_request(e.orig.args[0])
|
|
||||||
|
|
||||||
model_processor.run()
|
|
||||||
|
|
||||||
return web.json_response(model.to_dict(), dumps=dumps)
|
|
||||||
|
|
||||||
@routes.delete("/v2/models")
|
|
||||||
async def delete_model(request):
|
|
||||||
with create_session() as session:
|
|
||||||
model_path = request.query.get("path", None)
|
|
||||||
model_type = request.query.get("type", None)
|
|
||||||
if not model_path:
|
|
||||||
return missing_field("path")
|
|
||||||
if not model_type:
|
|
||||||
return missing_field("type")
|
|
||||||
|
|
||||||
full_path = get_full_path(model_type, model_path)
|
|
||||||
if full_path:
|
|
||||||
return bad_request("Model file exists, please delete the file before deleting the model record.")
|
|
||||||
|
|
||||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
|
||||||
if not model:
|
|
||||||
return not_found("Model")
|
|
||||||
session.delete(model)
|
|
||||||
session.commit()
|
|
||||||
return web.Response(status=204)
|
|
||||||
|
|
||||||
@routes.get("/v2/tags")
|
|
||||||
async def get_tags(request):
|
|
||||||
with create_session() as session:
|
|
||||||
tags = session.query(Tag).all()
|
|
||||||
return web.json_response(
|
|
||||||
[{"id": tag.id, "name": tag.name} for tag in tags]
|
|
||||||
)
|
|
||||||
|
|
||||||
@routes.post("/v2/tags")
|
|
||||||
async def create_tag(request):
|
|
||||||
with create_session() as session:
|
|
||||||
data = await request.json()
|
|
||||||
name = data.get("name", None)
|
|
||||||
if not name:
|
|
||||||
return missing_field("name")
|
|
||||||
tag = Tag(name=name)
|
|
||||||
session.add(tag)
|
|
||||||
session.commit()
|
|
||||||
return web.json_response({"id": tag.id, "name": tag.name})
|
|
||||||
|
|
||||||
@routes.delete("/v2/tags")
|
|
||||||
async def delete_tag(request):
|
|
||||||
with create_session() as session:
|
|
||||||
tag_id = request.query.get("id", None)
|
|
||||||
if not tag_id:
|
|
||||||
return missing_field("id")
|
|
||||||
tag = session.query(Tag).filter(Tag.id == tag_id).first()
|
|
||||||
if not tag:
|
|
||||||
return not_found("Tag")
|
|
||||||
session.delete(tag)
|
|
||||||
session.commit()
|
|
||||||
return web.Response(status=204)
|
|
||||||
|
|
||||||
@routes.post("/v2/models/tags")
|
|
||||||
async def add_model_tag(request):
|
|
||||||
with create_session() as session:
|
|
||||||
data = await request.json()
|
|
||||||
tag_id = data.get("tag", None)
|
|
||||||
model_path = data.get("path", None)
|
|
||||||
model_type = data.get("type", None)
|
|
||||||
|
|
||||||
if tag_id is None:
|
|
||||||
return missing_field("tag")
|
|
||||||
if model_path is None:
|
|
||||||
return missing_field("path")
|
|
||||||
if model_type is None:
|
|
||||||
return missing_field("type")
|
|
||||||
|
|
||||||
try:
|
|
||||||
tag_id = int(tag_id)
|
|
||||||
except ValueError:
|
|
||||||
return bad_request("Invalid tag id")
|
|
||||||
|
|
||||||
tag = session.query(Tag).filter(Tag.id == tag_id).first()
|
|
||||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
|
||||||
if not model:
|
|
||||||
return not_found("Model")
|
|
||||||
model.tags.append(tag)
|
|
||||||
session.commit()
|
|
||||||
return web.json_response(model.to_dict(), dumps=dumps)
|
|
||||||
|
|
||||||
@routes.delete("/v2/models/tags")
|
|
||||||
async def delete_model_tag(request):
|
|
||||||
with create_session() as session:
|
|
||||||
tag_id = request.query.get("tag", None)
|
|
||||||
model_path = request.query.get("path", None)
|
|
||||||
model_type = request.query.get("type", None)
|
|
||||||
|
|
||||||
if tag_id is None:
|
|
||||||
return missing_field("tag")
|
|
||||||
if model_path is None:
|
|
||||||
return missing_field("path")
|
|
||||||
if model_type is None:
|
|
||||||
return missing_field("type")
|
|
||||||
|
|
||||||
try:
|
|
||||||
tag_id = int(tag_id)
|
|
||||||
except ValueError:
|
|
||||||
return bad_request("Invalid tag id")
|
|
||||||
|
|
||||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
|
||||||
if not model:
|
|
||||||
return not_found("Model")
|
|
||||||
model.tags = [tag for tag in model.tags if tag.id != tag_id]
|
|
||||||
session.commit()
|
|
||||||
return web.Response(status=204)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/v2/models/missing")
|
|
||||||
async def get_missing_models(request):
|
|
||||||
return web.json_response(model_processor.missing_models)
|
|
||||||
|
|
||||||
def get_model_file_list(self, folder_name: str):
|
def get_model_file_list(self, folder_name: str):
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||||
@@ -334,5 +146,39 @@ class ModelFileManager:
|
|||||||
|
|
||||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||||
|
|
||||||
|
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||||
|
dirname = os.path.dirname(filepath)
|
||||||
|
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
return []
|
||||||
|
|
||||||
|
basename = os.path.splitext(filepath)[0]
|
||||||
|
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||||
|
image_files = filter_files_content_types(match_files, "image")
|
||||||
|
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||||
|
safetensors_metadata = {}
|
||||||
|
|
||||||
|
result: list[str | BytesIO] = []
|
||||||
|
|
||||||
|
for filename in image_files:
|
||||||
|
_basename = os.path.splitext(filename)[0]
|
||||||
|
if _basename == basename:
|
||||||
|
result.append(filename)
|
||||||
|
if _basename == f"{basename}.preview":
|
||||||
|
result.append(filename)
|
||||||
|
|
||||||
|
if safetensors_file:
|
||||||
|
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||||
|
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||||
|
if header:
|
||||||
|
safetensors_metadata = json.loads(header)
|
||||||
|
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||||
|
if safetensors_images:
|
||||||
|
safetensors_images = json.loads(safetensors_images)
|
||||||
|
for image in safetensors_images:
|
||||||
|
result.append(BytesIO(base64.b64decode(image)))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
|
|||||||
@@ -1,263 +0,0 @@
|
|||||||
import base64
|
|
||||||
from datetime import datetime
|
|
||||||
import glob
|
|
||||||
import hashlib
|
|
||||||
from io import BytesIO
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import comfy.utils
|
|
||||||
from app.database.models import Model
|
|
||||||
from app.database.db import create_session
|
|
||||||
from comfy.cli_args import args
|
|
||||||
from folder_paths import (
|
|
||||||
filter_files_content_types,
|
|
||||||
get_full_path,
|
|
||||||
folder_names_and_paths,
|
|
||||||
get_filename_list,
|
|
||||||
)
|
|
||||||
from PIL import Image
|
|
||||||
from urllib import request
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_previews(
|
|
||||||
filepath: str, check_metadata: bool = True
|
|
||||||
) -> list[str | BytesIO]:
|
|
||||||
dirname = os.path.dirname(filepath)
|
|
||||||
|
|
||||||
if not os.path.exists(dirname):
|
|
||||||
return []
|
|
||||||
|
|
||||||
basename = os.path.splitext(filepath)[0]
|
|
||||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
|
||||||
image_files = filter_files_content_types(match_files, "image")
|
|
||||||
|
|
||||||
result: list[str | BytesIO] = []
|
|
||||||
|
|
||||||
for filename in image_files:
|
|
||||||
_basename = os.path.splitext(filename)[0]
|
|
||||||
if _basename == basename:
|
|
||||||
result.append(filename)
|
|
||||||
if _basename == f"{basename}.preview":
|
|
||||||
result.append(filename)
|
|
||||||
|
|
||||||
if not check_metadata:
|
|
||||||
return result
|
|
||||||
|
|
||||||
safetensors_file = next(
|
|
||||||
filter(lambda x: x.endswith(".safetensors"), match_files), None
|
|
||||||
)
|
|
||||||
safetensors_metadata = {}
|
|
||||||
|
|
||||||
if safetensors_file:
|
|
||||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
|
||||||
header = comfy.utils.safetensors_header(
|
|
||||||
safetensors_filepath, max_size=8 * 1024 * 1024
|
|
||||||
)
|
|
||||||
if header:
|
|
||||||
safetensors_metadata = json.loads(header)
|
|
||||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
|
|
||||||
"ssmd_cover_images", None
|
|
||||||
)
|
|
||||||
if safetensors_images:
|
|
||||||
safetensors_images = json.loads(safetensors_images)
|
|
||||||
for image in safetensors_images:
|
|
||||||
result.append(BytesIO(base64.b64decode(image)))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProcessor:
|
|
||||||
def __init__(self):
|
|
||||||
self._thread = None
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._run = False
|
|
||||||
self.missing_models = []
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
if args.disable_model_processing:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._thread is None:
|
|
||||||
# Lock to prevent multiple threads from starting
|
|
||||||
with self._lock:
|
|
||||||
self._run = True
|
|
||||||
if self._thread is None:
|
|
||||||
self._thread = threading.Thread(target=self._process_models)
|
|
||||||
self._thread.daemon = True
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def populate_models(self, session):
|
|
||||||
# Ensure database state matches filesystem
|
|
||||||
|
|
||||||
existing_models = session.query(Model).all()
|
|
||||||
|
|
||||||
for folder_name in folder_names_and_paths.keys():
|
|
||||||
if folder_name == "custom_nodes" or folder_name == "configs":
|
|
||||||
continue
|
|
||||||
seen = set()
|
|
||||||
files = get_filename_list(folder_name)
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if file in seen:
|
|
||||||
logging.warning(f"Skipping duplicate named model: {file}")
|
|
||||||
continue
|
|
||||||
seen.add(file)
|
|
||||||
|
|
||||||
existing_model = None
|
|
||||||
for model in existing_models:
|
|
||||||
if model.path == file and model.type == folder_name:
|
|
||||||
existing_model = model
|
|
||||||
break
|
|
||||||
|
|
||||||
if existing_model:
|
|
||||||
# Model already exists in db, remove from list and skip
|
|
||||||
existing_models.remove(existing_model)
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_path = get_full_path(folder_name, file)
|
|
||||||
|
|
||||||
model = Model(
|
|
||||||
path=file,
|
|
||||||
type=folder_name,
|
|
||||||
date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
|
|
||||||
for model in existing_models:
|
|
||||||
if not get_full_path(model.type, model.path):
|
|
||||||
logging.warning(f"Model {model.path} not found")
|
|
||||||
self.missing_models.append({"type": model.type, "path": model.path})
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def _get_models(self, session):
|
|
||||||
models = session.query(Model).filter(Model.hash == None).all()
|
|
||||||
return models
|
|
||||||
|
|
||||||
def _process_file(self, model_path):
|
|
||||||
is_safetensors = model_path.endswith(".safetensors")
|
|
||||||
metadata = {}
|
|
||||||
h = hashlib.sha256()
|
|
||||||
|
|
||||||
with open(model_path, "rb", buffering=0) as f:
|
|
||||||
if is_safetensors:
|
|
||||||
# Read header length (8 bytes)
|
|
||||||
header_size_bytes = f.read(8)
|
|
||||||
header_len = int.from_bytes(header_size_bytes, "little")
|
|
||||||
h.update(header_size_bytes)
|
|
||||||
|
|
||||||
# Read header
|
|
||||||
header_bytes = f.read(header_len)
|
|
||||||
h.update(header_bytes)
|
|
||||||
try:
|
|
||||||
metadata = json.loads(header_bytes)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Read rest of file
|
|
||||||
b = bytearray(128 * 1024)
|
|
||||||
mv = memoryview(b)
|
|
||||||
while n := f.readinto(mv):
|
|
||||||
h.update(mv[:n])
|
|
||||||
|
|
||||||
return h.hexdigest(), metadata
|
|
||||||
|
|
||||||
def _populate_info(self, model, metadata):
|
|
||||||
model.title = metadata.get("modelspec.title", None)
|
|
||||||
model.description = metadata.get("modelspec.description", None)
|
|
||||||
model.architecture = metadata.get("modelspec.architecture", None)
|
|
||||||
|
|
||||||
def _extract_image(self, model_path, metadata):
|
|
||||||
# check if image already exists
|
|
||||||
if len(get_model_previews(model_path, check_metadata=False)) > 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
image_path = os.path.splitext(model_path)[0] + ".webp"
|
|
||||||
if os.path.exists(image_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
cover_images = metadata.get("ssmd_cover_images", None)
|
|
||||||
image = None
|
|
||||||
if cover_images:
|
|
||||||
try:
|
|
||||||
cover_images = json.loads(cover_images)
|
|
||||||
if len(cover_images) > 0:
|
|
||||||
image_data = cover_images[0]
|
|
||||||
image = Image.open(BytesIO(base64.b64decode(image_data)))
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(
|
|
||||||
f"Error extracting cover image for model {model_path}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not image:
|
|
||||||
thumbnail = metadata.get("modelspec.thumbnail", None)
|
|
||||||
if thumbnail:
|
|
||||||
try:
|
|
||||||
response = request.urlopen(thumbnail)
|
|
||||||
image = Image.open(response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(
|
|
||||||
f"Error extracting thumbnail for model {model_path}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if image:
|
|
||||||
image.thumbnail((512, 512))
|
|
||||||
image.save(image_path)
|
|
||||||
image.close()
|
|
||||||
|
|
||||||
def _process_models(self):
|
|
||||||
with create_session() as session:
|
|
||||||
checked = set()
|
|
||||||
self.populate_models(session)
|
|
||||||
|
|
||||||
while self._run:
|
|
||||||
self._run = False
|
|
||||||
|
|
||||||
models = self._get_models(session)
|
|
||||||
|
|
||||||
if len(models) == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
for model in models:
|
|
||||||
# prevent looping on the same model if it crashes
|
|
||||||
if model.path in checked:
|
|
||||||
continue
|
|
||||||
|
|
||||||
checked.add(model.path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
time.sleep(0)
|
|
||||||
now = time.time()
|
|
||||||
model_path = get_full_path(model.type, model.path)
|
|
||||||
|
|
||||||
if not model_path:
|
|
||||||
logging.warning(f"Model {model.path} not found")
|
|
||||||
self.missing_models.append(model.path)
|
|
||||||
continue
|
|
||||||
|
|
||||||
logging.debug(f"Processing model {model_path}")
|
|
||||||
hash, header = self._process_file(model_path)
|
|
||||||
logging.debug(
|
|
||||||
f"Processed model {model_path} in {time.time() - now} seconds"
|
|
||||||
)
|
|
||||||
model.hash = hash
|
|
||||||
|
|
||||||
if header:
|
|
||||||
metadata = header.get("__metadata__", None)
|
|
||||||
|
|
||||||
if metadata:
|
|
||||||
self._populate_info(model, metadata)
|
|
||||||
self._extract_image(model_path, metadata)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error processing model {model.path}: {e}")
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
|
|
||||||
model_processor = ModelProcessor()
|
|
||||||
@@ -143,9 +143,13 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
|||||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||||
|
|
||||||
|
parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.")
|
||||||
|
parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.")
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
# The default built-in provider hosted under web/
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--front-end-version",
|
"--front-end-version",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -178,12 +182,6 @@ parser.add_argument(
|
|||||||
|
|
||||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||||
|
|
||||||
database_default_path = os.path.abspath(
|
|
||||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
|
||||||
)
|
|
||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
|
||||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
11
main.py
11
main.py
@@ -138,8 +138,6 @@ import server
|
|||||||
from server import BinaryEventTypes
|
from server import BinaryEventTypes
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from app.database.db import can_create_session, init_db
|
|
||||||
from app.model_processor import model_processor
|
|
||||||
|
|
||||||
def cuda_malloc_warning():
|
def cuda_malloc_warning():
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
@@ -264,11 +262,6 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
|
|
||||||
try:
|
|
||||||
init_db()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
|
|
||||||
|
|
||||||
prompt_server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(prompt_server)
|
hijack_progress(prompt_server)
|
||||||
|
|
||||||
@@ -277,10 +270,6 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# Scan for changed model files and update db
|
|
||||||
if can_create_session():
|
|
||||||
model_processor.run()
|
|
||||||
|
|
||||||
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ Pillow
|
|||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
psutil
|
psutil
|
||||||
alembic
|
|
||||||
SQLAlchemy
|
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
|
|||||||
17
server.py
17
server.py
@@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager
|
|||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
from app.database.entities import get_entity, init_entities
|
||||||
|
from app.database.db import db
|
||||||
|
from app.model_hasher import model_hasher
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
@@ -682,11 +685,25 @@ class PromptServer():
|
|||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
||||||
|
def init_db(self, routes):
|
||||||
|
init_entities(routes)
|
||||||
|
models = get_entity("models")
|
||||||
|
|
||||||
|
if db.exists:
|
||||||
|
model_hasher.start(models)
|
||||||
|
else:
|
||||||
|
def on_db_update(_, __):
|
||||||
|
model_hasher.start(models)
|
||||||
|
|
||||||
|
db.register_after_update_callback(on_db_update)
|
||||||
|
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
self.model_file_manager.add_routes(self.routes)
|
self.model_file_manager.add_routes(self.routes)
|
||||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||||
|
self.init_db(self.routes)
|
||||||
|
|
||||||
# Prefix every route with /api for easier matching for delegation.
|
# Prefix every route with /api for easier matching for delegation.
|
||||||
# This is very useful for frontend dev server, which need to forward
|
# This is very useful for frontend dev server, which need to forward
|
||||||
|
|||||||
513
tests-unit/app_test/entities_test.py
Normal file
513
tests-unit/app_test/entities_test.py
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
args.memory_database = True # force in-memory database for testing
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
from aiohttp import web
|
||||||
|
from app.database.entities import (
|
||||||
|
DeleteEntity,
|
||||||
|
column,
|
||||||
|
table,
|
||||||
|
Column,
|
||||||
|
GetEntity,
|
||||||
|
GetEntityById,
|
||||||
|
CreateEntity,
|
||||||
|
UpsertEntity,
|
||||||
|
UpdateEntity,
|
||||||
|
)
|
||||||
|
from app.database.db import db
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def create_table(entity):
|
||||||
|
# reset db
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
cols: list[Column] = entity._columns
|
||||||
|
# Create tables as temporary so when we close the db, the tables are dropped for next test
|
||||||
|
sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
|
||||||
|
for col_name, col in cols.items():
|
||||||
|
type = None
|
||||||
|
if col.type == int:
|
||||||
|
type = "INTEGER"
|
||||||
|
elif col.type == str:
|
||||||
|
type = "TEXT"
|
||||||
|
|
||||||
|
sql += f"{col_name} {type}"
|
||||||
|
if col.required:
|
||||||
|
sql += " NOT NULL"
|
||||||
|
sql += ", "
|
||||||
|
|
||||||
|
sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
|
||||||
|
sql += ")"
|
||||||
|
db.execute(sql)
|
||||||
|
|
||||||
|
|
||||||
|
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
|
||||||
|
with patch.object(db, "execute", wraps=db.execute) as mock:
|
||||||
|
response = await method()
|
||||||
|
assert mock.call_count == 1
|
||||||
|
assert mock.call_args[0][0] == expected_sql
|
||||||
|
assert mock.call_args[0][1:] == expected_args
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def getable_entity():
|
||||||
|
@table("getable_entity")
|
||||||
|
class GetableEntity(GetEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
nullable: Optional[str] = column(str)
|
||||||
|
|
||||||
|
return GetableEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def getable_by_id_entity():
|
||||||
|
@table("getable_by_id_entity")
|
||||||
|
class GetableByIdEntity(GetEntityById):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
|
||||||
|
return GetableByIdEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def getable_by_id_composite_entity():
|
||||||
|
@table("getable_by_id_composite_entity")
|
||||||
|
class GetableByIdCompositeEntity(GetEntityById):
|
||||||
|
id1: str = column(str, required=True, key=True)
|
||||||
|
id2: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
|
||||||
|
return GetableByIdCompositeEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def creatable_entity():
|
||||||
|
@table("creatable_entity")
|
||||||
|
class CreatableEntity(CreateEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
reqd: str = column(str, required=True)
|
||||||
|
nullable: Optional[str] = column(str)
|
||||||
|
|
||||||
|
return CreatableEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def upsertable_entity():
|
||||||
|
@table("upsertable_entity")
|
||||||
|
class UpsertableEntity(UpsertEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
test: str = column(str, required=True)
|
||||||
|
reqd: str = column(str, required=True)
|
||||||
|
nullable: Optional[str] = column(str)
|
||||||
|
|
||||||
|
return UpsertableEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def updateable_entity():
|
||||||
|
@table("updateable_entity")
|
||||||
|
class UpdateableEntity(UpdateEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
reqd: str = column(str, required=True)
|
||||||
|
|
||||||
|
return UpdateableEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def deletable_entity():
|
||||||
|
@table("deletable_entity")
|
||||||
|
class DeletableEntity(DeleteEntity):
|
||||||
|
id: int = column(int, required=True, key=True)
|
||||||
|
|
||||||
|
return DeletableEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def deletable_composite_entity():
|
||||||
|
@table("deletable_composite_entity")
|
||||||
|
class DeletableCompositeEntity(DeleteEntity):
|
||||||
|
id1: str = column(str, required=True, key=True)
|
||||||
|
id2: int = column(int, required=True, key=True)
|
||||||
|
|
||||||
|
return DeletableCompositeEntity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def entity(request):
|
||||||
|
value = request.getfixturevalue(request.param)
|
||||||
|
create_table(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(aiohttp_client, app):
|
||||||
|
return await aiohttp_client(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(entity):
|
||||||
|
app = web.Application()
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
entity.register_route(routes)
|
||||||
|
app.add_routes(routes)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||||
|
async def test_get_model_empty_response(client):
|
||||||
|
expected_sql = "SELECT * FROM getable_entity"
|
||||||
|
expected_args = ()
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||||
|
async def test_get_model_with_data(client):
|
||||||
|
# seed db
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_sql = "SELECT * FROM getable_entity"
|
||||||
|
expected_args = ()
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == [
|
||||||
|
{"id": 1, "test": "test1", "nullable": None},
|
||||||
|
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||||
|
async def test_get_model_with_top_parameter(client):
|
||||||
|
# seed with 3 rows
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
|
||||||
|
expected_args = ()
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.get("/db/getable_entity?top=2"),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == [
|
||||||
|
{"id": 1, "test": "test1", "nullable": None},
|
||||||
|
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||||
|
async def test_get_model_with_invalid_top_parameter(client):
|
||||||
|
response = await client.get("/db/getable_entity?top=hello")
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid top parameter",
|
||||||
|
"field": "top",
|
||||||
|
"value": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||||
|
async def test_get_model_by_id_empty_response(client):
|
||||||
|
# seed db
|
||||||
|
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
|
||||||
|
|
||||||
|
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
|
||||||
|
expected_args = (1,)
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.get("/db/getable_by_id_entity/1"),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == [
|
||||||
|
{"id": 1, "test": "test1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||||
|
async def test_get_model_by_id_with_invalid_id(client):
|
||||||
|
response = await client.get("/db/getable_by_id_entity/hello")
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid value",
|
||||||
|
"field": "id",
|
||||||
|
"value": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||||
|
async def test_get_model_by_id_composite(client):
|
||||||
|
# seed db
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_sql = (
|
||||||
|
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||||
|
)
|
||||||
|
expected_args = ("one", 2)
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == [
|
||||||
|
{"id1": "one", "id2": 2, "test": "test"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||||
|
async def test_get_model_by_id_composite_with_invalid_id(client):
|
||||||
|
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid value",
|
||||||
|
"field": "id2",
|
||||||
|
"value": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model(client):
|
||||||
|
expected_sql = (
|
||||||
|
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
|
||||||
|
)
|
||||||
|
expected_args = (1, "test1", "reqd1")
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.post(
|
||||||
|
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||||
|
),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == {
|
||||||
|
"id": 1,
|
||||||
|
"test": "test1",
|
||||||
|
"reqd": "reqd1",
|
||||||
|
"nullable": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_missing_required_field(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity", json={"id": 1, "test": "test1"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Missing field",
|
||||||
|
"field": "reqd",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_missing_key_field(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity",
|
||||||
|
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Missing field",
|
||||||
|
"field": "id",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_invalid_key_data(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity",
|
||||||
|
json={
|
||||||
|
"id": "not_an_integer",
|
||||||
|
"test": "test1",
|
||||||
|
"reqd": "reqd1",
|
||||||
|
}, # id should be int
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid value",
|
||||||
|
"field": "id",
|
||||||
|
"value": "not_an_integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_invalid_field_data(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity",
|
||||||
|
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid value",
|
||||||
|
"field": "id",
|
||||||
|
"value": "aaa",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_invalid_field_type(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity",
|
||||||
|
json={
|
||||||
|
"id": 1,
|
||||||
|
"test": ["invalid_array"],
|
||||||
|
"reqd": "reqd1",
|
||||||
|
}, # test should be string
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Invalid value",
|
||||||
|
"field": "test",
|
||||||
|
"value": ["invalid_array"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||||
|
async def test_create_model_invalid_field_name(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/db/creatable_entity",
|
||||||
|
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Unknown field",
|
||||||
|
"field": "nonexistent_field",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
|
||||||
|
async def test_upsert_model(client):
|
||||||
|
expected_sql = (
|
||||||
|
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
|
||||||
|
"RETURNING *"
|
||||||
|
)
|
||||||
|
expected_args = (1, "test1", "reqd1")
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.put(
|
||||||
|
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||||
|
),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == {
|
||||||
|
"id": 1,
|
||||||
|
"test": "test1",
|
||||||
|
"reqd": "reqd1",
|
||||||
|
"nullable": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||||
|
async def test_update_model(client):
|
||||||
|
# seed db
|
||||||
|
db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")
|
||||||
|
|
||||||
|
expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
|
||||||
|
expected_args = ("updated_test", 1)
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.json() == {
|
||||||
|
"id": 1,
|
||||||
|
"reqd": "updated_test",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||||
|
async def test_update_model_reject_null_required_field(client):
|
||||||
|
response = await client.patch("/db/updateable_entity/1", json={"reqd": None})
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Required field",
|
||||||
|
"field": "reqd",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||||
|
async def test_update_model_reject_invalid_field(client):
|
||||||
|
response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Unknown field",
|
||||||
|
"field": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||||
|
async def test_update_model_reject_missing_record(client):
|
||||||
|
response = await client.patch(
|
||||||
|
"/db/updateable_entity/1", json={"reqd": "updated_test"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 404
|
||||||
|
assert await response.json() == {
|
||||||
|
"message": "Failed to update entity",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
|
||||||
|
async def test_delete_model(client):
|
||||||
|
expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
|
||||||
|
expected_args = (1,)
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.delete("/db/deletable_entity/1"),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 204
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
|
||||||
|
async def test_delete_model_composite_key(client):
|
||||||
|
expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||||
|
expected_args = ("one", 2)
|
||||||
|
response = await wrap_db(
|
||||||
|
lambda: client.delete("/db/deletable_composite_entity/one/2"),
|
||||||
|
expected_sql,
|
||||||
|
expected_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 204
|
||||||
@@ -7,33 +7,11 @@ from PIL import Image
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
from app.database.models import Base, Model, Tag
|
|
||||||
from comfy.cli_args import args
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
pytestmark = (
|
pytestmark = (
|
||||||
pytest.mark.asyncio
|
pytest.mark.asyncio
|
||||||
) # This applies the asyncio mark to all test functions in the module
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def session():
|
|
||||||
# Configure in-memory database
|
|
||||||
args.database_url = "sqlite:///:memory:"
|
|
||||||
|
|
||||||
# Create engine and session factory
|
|
||||||
engine = create_engine(args.database_url)
|
|
||||||
Session = sessionmaker(bind=engine)
|
|
||||||
|
|
||||||
# Create all tables
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
|
|
||||||
# Patch Session factory
|
|
||||||
with patch('app.database.db.Session', Session):
|
|
||||||
yield Session()
|
|
||||||
|
|
||||||
Base.metadata.drop_all(engine)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_manager():
|
def model_manager():
|
||||||
return ModelFileManager()
|
return ModelFileManager()
|
||||||
@@ -82,287 +60,3 @@ async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
|||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
img.close()
|
img.close()
|
||||||
|
|
||||||
async def test_get_models(aiohttp_client, app, session):
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
model.tags.append(tag)
|
|
||||||
session.add(tag)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.get('/v2/models')
|
|
||||||
assert resp.status == 200
|
|
||||||
data = await resp.json()
|
|
||||||
assert len(data) == 1
|
|
||||||
assert data[0]['path'] == 'model1.safetensors'
|
|
||||||
assert len(data[0]['tags']) == 1
|
|
||||||
assert data[0]['tags'][0]['name'] == 'test_tag'
|
|
||||||
|
|
||||||
async def test_add_model(aiohttp_client, app, session):
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
session.add(tag)
|
|
||||||
session.commit()
|
|
||||||
tag_id = tag.id
|
|
||||||
|
|
||||||
with patch('app.model_manager.model_processor') as mock_processor:
|
|
||||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'model1.safetensors',
|
|
||||||
'title': 'Test Model',
|
|
||||||
'tags': [tag_id]
|
|
||||||
})
|
|
||||||
|
|
||||||
assert resp.status == 200
|
|
||||||
data = await resp.json()
|
|
||||||
assert data['path'] == 'model1.safetensors'
|
|
||||||
assert len(data['tags']) == 1
|
|
||||||
assert data['tags'][0]['name'] == 'test_tag'
|
|
||||||
|
|
||||||
# Ensure that models are re-processed after adding
|
|
||||||
mock_processor.run.assert_called_once()
|
|
||||||
|
|
||||||
async def test_delete_model(aiohttp_client, app, session):
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
with patch('app.model_manager.get_full_path', return_value=None):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
|
|
||||||
assert resp.status == 204
|
|
||||||
|
|
||||||
# Verify model was deleted
|
|
||||||
model = session.query(Model).first()
|
|
||||||
assert model is None
|
|
||||||
|
|
||||||
async def test_delete_model_file_exists(aiohttp_client, app, session):
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete('/v2/models?type=checkpoints&path=model1.safetensors')
|
|
||||||
assert resp.status == 400
|
|
||||||
|
|
||||||
data = await resp.json()
|
|
||||||
assert "file exists" in data["error"].lower()
|
|
||||||
|
|
||||||
# Verify model was not deleted
|
|
||||||
model = session.query(Model).first()
|
|
||||||
assert model is not None
|
|
||||||
assert model.path == 'model1.safetensors'
|
|
||||||
|
|
||||||
async def test_get_tags(aiohttp_client, app, session):
|
|
||||||
tags = [Tag(name='tag1'), Tag(name='tag2')]
|
|
||||||
for tag in tags:
|
|
||||||
session.add(tag)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.get('/v2/tags')
|
|
||||||
assert resp.status == 200
|
|
||||||
data = await resp.json()
|
|
||||||
assert len(data) == 2
|
|
||||||
assert {t['name'] for t in data} == {'tag1', 'tag2'}
|
|
||||||
|
|
||||||
async def test_create_tag(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/tags', json={'name': 'new_tag'})
|
|
||||||
assert resp.status == 200
|
|
||||||
data = await resp.json()
|
|
||||||
assert data['name'] == 'new_tag'
|
|
||||||
|
|
||||||
# Verify tag was created
|
|
||||||
tag = session.query(Tag).first()
|
|
||||||
assert tag.name == 'new_tag'
|
|
||||||
|
|
||||||
async def test_delete_tag(aiohttp_client, app, session):
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
session.add(tag)
|
|
||||||
session.commit()
|
|
||||||
tag_id = tag.id
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete(f'/v2/tags?id={tag_id}')
|
|
||||||
assert resp.status == 204
|
|
||||||
|
|
||||||
# Verify tag was deleted
|
|
||||||
tag = session.query(Tag).first()
|
|
||||||
assert tag is None
|
|
||||||
|
|
||||||
async def test_add_model_tag(aiohttp_client, app, session):
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
session.add(tag)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
tag_id = tag.id
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models/tags', json={
|
|
||||||
'tag': tag_id,
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'model1.safetensors'
|
|
||||||
})
|
|
||||||
assert resp.status == 200
|
|
||||||
data = await resp.json()
|
|
||||||
assert len(data['tags']) == 1
|
|
||||||
assert data['tags'][0]['name'] == 'test_tag'
|
|
||||||
|
|
||||||
async def test_delete_model_tag(aiohttp_client, app, session):
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
model.tags.append(tag)
|
|
||||||
session.add(tag)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
tag_id = tag.id
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete(f'/v2/models/tags?tag={tag_id}&type=checkpoints&path=model1.safetensors')
|
|
||||||
assert resp.status == 204
|
|
||||||
|
|
||||||
# Verify tag was removed
|
|
||||||
model = session.query(Model).first()
|
|
||||||
assert len(model.tags) == 0
|
|
||||||
|
|
||||||
async def test_add_model_duplicate(aiohttp_client, app, session):
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'model1.safetensors',
|
|
||||||
'title': 'Duplicate Model'
|
|
||||||
})
|
|
||||||
assert resp.status == 400
|
|
||||||
|
|
||||||
async def test_add_model_missing_fields(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={})
|
|
||||||
assert resp.status == 400
|
|
||||||
|
|
||||||
async def test_add_tag_missing_name(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/tags', json={})
|
|
||||||
assert resp.status == 400
|
|
||||||
|
|
||||||
async def test_delete_model_not_found(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete('/v2/models?type=checkpoints&path=nonexistent.safetensors')
|
|
||||||
assert resp.status == 404
|
|
||||||
|
|
||||||
async def test_delete_tag_not_found(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete('/v2/tags?id=999')
|
|
||||||
assert resp.status == 404
|
|
||||||
|
|
||||||
async def test_add_model_missing_path(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'title': 'Test Model'
|
|
||||||
})
|
|
||||||
assert resp.status == 400
|
|
||||||
data = await resp.json()
|
|
||||||
assert "path" in data["error"].lower()
|
|
||||||
|
|
||||||
async def test_add_model_invalid_field(aiohttp_client, app, session):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'model1.safetensors',
|
|
||||||
'invalid_field': 'some value'
|
|
||||||
})
|
|
||||||
assert resp.status == 400
|
|
||||||
data = await resp.json()
|
|
||||||
assert "invalid field" in data["error"].lower()
|
|
||||||
|
|
||||||
async def test_add_model_nonexistent_file(aiohttp_client, app, session):
|
|
||||||
with patch('app.model_manager.get_full_path', return_value=None):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'nonexistent.safetensors'
|
|
||||||
})
|
|
||||||
assert resp.status == 404
|
|
||||||
data = await resp.json()
|
|
||||||
assert "file" in data["error"].lower()
|
|
||||||
|
|
||||||
async def test_add_model_invalid_tag(aiohttp_client, app, session):
|
|
||||||
with patch('app.model_manager.get_full_path', return_value='/checkpoints/model1.safetensors'):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models', json={
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'model1.safetensors',
|
|
||||||
'tags': [999] # Non-existent tag ID
|
|
||||||
})
|
|
||||||
assert resp.status == 404
|
|
||||||
data = await resp.json()
|
|
||||||
assert "tag" in data["error"].lower()
|
|
||||||
|
|
||||||
async def test_add_tag_to_nonexistent_model(aiohttp_client, app, session):
|
|
||||||
# Create a tag but no model
|
|
||||||
tag = Tag(name='test_tag')
|
|
||||||
session.add(tag)
|
|
||||||
session.commit()
|
|
||||||
tag_id = tag.id
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.post('/v2/models/tags', json={
|
|
||||||
'tag': tag_id,
|
|
||||||
'type': 'checkpoints',
|
|
||||||
'path': 'nonexistent.safetensors'
|
|
||||||
})
|
|
||||||
assert resp.status == 404
|
|
||||||
data = await resp.json()
|
|
||||||
assert "model" in data["error"].lower()
|
|
||||||
|
|
||||||
async def test_delete_model_tag_invalid_tag_id(aiohttp_client, app, session):
|
|
||||||
# Create a model first
|
|
||||||
model = Model(
|
|
||||||
type='checkpoints',
|
|
||||||
path='model1.safetensors',
|
|
||||||
title='Test Model'
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
resp = await client.delete('/v2/models/tags?tag=not_a_number&type=checkpoint&path=model1.safetensors')
|
|
||||||
assert resp.status == 400
|
|
||||||
data = await resp.json()
|
|
||||||
assert "invalid tag id" in data["error"].lower()
|
|
||||||
|
|
||||||
|
|||||||
12
utils/web.py
12
utils/web.py
@@ -1,12 +0,0 @@
|
|||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class DateTimeEncoder(json.JSONEncoder):
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
return super().default(obj)
|
|
||||||
|
|
||||||
|
|
||||||
dumps = DateTimeEncoder().encode
|
|
||||||
Reference in New Issue
Block a user