diff --git a/app/database/__init__.py b/app/database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/database.py b/app/database/db.py similarity index 69% rename from app/database.py rename to app/database/db.py index 6bbb14eb..9ed9c9ef 100644 --- a/app/database.py +++ b/app/database/db.py @@ -1,9 +1,10 @@ import logging +import os import sqlite3 from contextlib import contextmanager from queue import Queue, Empty, Full import threading -from app.database_updater import DatabaseUpdater +from app.database.updater import DatabaseUpdater import folder_paths from comfy.cli_args import args @@ -11,7 +12,10 @@ from comfy.cli_args import args class Database: def __init__(self, database_path=None, pool_size=1): if database_path is None: + self.exists = False database_path = "file::memory:?cache=shared" + else: + self.exists = os.path.exists(database_path) self.database_path = database_path self.pool_size = pool_size @@ -20,6 +24,7 @@ class Database: self._db_lock = threading.Lock() self._initialized = False self._closing = False + self._after_update_callbacks = [] def _setup(self): if self._initialized: @@ -33,14 +38,28 @@ class Database: def _create_connection(self): # TODO: Catch error for sqlite lib missing on linux logging.info(f"Creating connection to {self.database_path}") - conn = sqlite3.connect(self.database_path, check_same_thread=False) + conn = sqlite3.connect( + self.database_path, + check_same_thread=False, + uri=self.database_path.startswith("file::"), + ) conn.execute("PRAGMA foreign_keys = ON") + self.exists = True + logging.info(f"Connected!") return conn def _make_db(self): with self._get_connection() as con: - updater = DatabaseUpdater(con) - updater.update() + updater = DatabaseUpdater(con, self.database_path) + result = updater.update() + if result is not None: + old_version, new_version = result + + for callback in self._after_update_callbacks: + callback(old_version, new_version) + + def _transform(self, row, columns): + return {col.name: value for value, col in zip(row, columns)} @contextmanager def _get_connection(self): @@ -71,6 +90,15 @@ class Database: with self._get_connection() as connection: yield connection + def execute(self, sql, *args): + with self.get_connection() as connection: + cursor = connection.execute(sql, args) + results = cursor.fetchall() + return results + + def register_after_update_callback(self, callback): + self._after_update_callbacks.append(callback) + def close(self): if self._closing: return @@ -82,6 +110,7 @@ class Database: conn.close() except Empty: break + self._closing = False def __del__(self): try: diff --git a/app/database/entities.py b/app/database/entities.py new file mode 100644 index 00000000..c608ace3 --- /dev/null +++ b/app/database/entities.py @@ -0,0 +1,301 @@ +from typing import Optional, Any, Callable +from dataclasses import dataclass +from functools import wraps +from aiohttp import web +from app.database.db import db + +primitives = (bool, str, int, float, type(None)) + + +def is_primitive(obj): + return isinstance(obj, primitives) + + +class ValidationError(Exception): + def __init__(self, message: str, field: str = None, value: Any = None): + self.message = message + self.field = field + self.value = value + super().__init__(self.message) + + def to_json(self): + result = {"message": self.message} + if self.field is not None: + result["field"] = self.field + if self.value is not None: + result["value"] = self.value + return result + + def __str__(self) -> str: + return f"{self.message} {self.field} {self.value}" + + +class EntityCommon(dict): + @classmethod + def _get_route(cls, include_key: bool): + route = f"/db/{cls.__table_name__}" + if include_key: + route += "".join([f"/{{{k}}}" for k in cls.__key_columns__]) + return route + + @classmethod + def _register_route(cls, routes, verb: str, include_key: bool, handler: Callable): + route = cls._get_route(include_key) + + @getattr(routes, verb)(route) + async def _(request): + try: + data = await handler(request) + return web.json_response(data) + except ValidationError as e: + return web.json_response(e.to_json(), status=400) + + @classmethod + def _transform(cls, row: list[Any]): + return {col: value for col, value in zip(cls.__columns__, row)} + + @classmethod + def _transform_rows(cls, rows: list[list[Any]]): + return [cls._transform(row) for row in rows] + + @classmethod + def _validate(cls, fields: list[str], data: dict, allow_missing: bool = False): + result = {} + + if not isinstance(data, dict): + raise ValidationError("Invalid data") + + # Ensure all required fields are present + for field in data: + if field not in fields: + raise ValidationError("Unknown field", field) + + for key in fields: + col = cls.__columns__[key] + if key not in data: + if col.required and not allow_missing: + raise ValidationError("Missing field", key) + else: + # e.g. for updates, we allow missing fields + continue + elif data[key] is None and col.required: + # Dont allow None for required fields + raise ValidationError("Required field", key) + + # Validate data type + value = data[key] + + if value is not None and not is_primitive(value): + raise ValidationError("Invalid value", key, value) + + try: + type = col.type + if value is not None and not isinstance(value, type): + value = type(value) + result[key] = value + except Exception: + raise ValidationError("Invalid value", key, value) + + return result + + @classmethod + def _validate_id(cls, id: dict): + return cls._validate(cls.__key_columns__, id) + + @classmethod + def _validate_data(cls, data: dict): + return cls._validate(cls.__columns__.keys(), data) + + def __setattr__(self, name, value): + if name in self.__columns__: + self[name] = value + super().__setattr__(name, value) + + def __getattr__(self, name): + if name in self: + return self[name] + raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'") + + +class GetEntity(EntityCommon): + @classmethod + def get(cls, top: Optional[int] = None, where: Optional[str] = None): + limit = "" + if top is not None and isinstance(top, int): + limit = f" LIMIT {top}" + result = db.execute( + f"SELECT * FROM {cls.__table_name__}{limit}{f' WHERE {where}' if where else ''}", + ) + + # Map each row in result to an instance of the class + return cls._transform_rows(result) + + @classmethod + def register_route(cls, routes): + async def get_handler(request): + top = request.rel_url.query.get("top", None) + if top is not None: + try: + top = int(top) + except Exception: + raise ValidationError("Invalid top parameter", "top", top) + return cls.get(top) + + cls._register_route(routes, "get", False, get_handler) + + +class GetEntityById(EntityCommon): + @classmethod + def get_by_id(cls, id: dict): + id = cls._validate_id(id) + + result = db.execute( + f"SELECT * FROM {cls.__table_name__} WHERE {cls.__where_clause__}", + *[id[key] for key in cls.__key_columns__], + ) + + return cls._transform_rows(result) + + @classmethod + def register_route(cls, routes): + async def get_by_id_handler(request): + id = {key: request.match_info.get(key, None) for key in cls.__key_columns__} + return cls.get_by_id(id) + + cls._register_route(routes, "get", True, get_by_id_handler) + + +class CreateEntity(EntityCommon): + @classmethod + def create(cls, data: dict, allow_upsert: bool = False): + data = cls._validate_data(data) + values = ", ".join(["?"] * len(data)) + on_conflict = "" + + data_keys = ", ".join(list(data.keys())) + if allow_upsert: + # Remove key columns from data + upsert_keys = [key for key in data if key not in cls.__key_columns__] + + set_clause = ", ".join([f"{k} = excluded.{k}" for k in upsert_keys]) + on_conflict = f" ON CONFLICT ({', '.join(cls.__key_columns__)}) DO UPDATE SET {set_clause}" + sql = f"INSERT INTO {cls.__table_name__} ({data_keys}) VALUES ({values}){on_conflict} RETURNING *" + result = db.execute( + sql, + *[data[key] for key in data], + ) + + if len(result) == 0: + raise RuntimeError("Failed to create entity") + + return cls._transform_rows(result)[0] + + @classmethod + def register_route(cls, routes): + async def create_handler(request): + data = await request.json() + return cls.create(data) + + cls._register_route(routes, "post", False, create_handler) + + +class UpdateEntity(EntityCommon): + @classmethod + def update(cls, id: list, data: dict): + pass + + +class UpsertEntity(CreateEntity): + @classmethod + def upsert(cls, data: dict): + return cls.create(data, allow_upsert=True) + + @classmethod + def register_route(cls, routes): + async def upsert_handler(request): + data = await request.json() + return cls.upsert(data) + + cls._register_route(routes, "put", False, upsert_handler) + + +class DeleteEntity(EntityCommon): + @classmethod + def delete(cls, id: list): + pass + + +class BaseEntity(GetEntity, CreateEntity, UpdateEntity, DeleteEntity, GetEntityById): + pass + + +@dataclass +class Column: + type: Any + required: bool = False + key: bool = False + default: Any = None + + +def column(type_: Any, required: bool = False, key: bool = False, default: Any = None): + return Column(type_, required, key, default) + + +def table(table_name: str): + def decorator(cls): + # Store table name + cls.__table_name__ = table_name + + # Process column definitions + columns: dict[str, Column] = {} + for attr_name, attr_value in cls.__dict__.items(): + if isinstance(attr_value, Column): + columns[attr_name] = attr_value + + # Store columns metadata + cls.__columns__ = columns + cls.__key_columns__ = [col for col in columns if columns[col].key] + cls.__column_csv__ = ", ".join([col for col in columns]) + cls.__where_clause__ = " AND ".join( + [f"{col} = ?" for col in cls.__key_columns__] + ) + + # Add initialization + original_init = cls.__init__ + + @wraps(original_init) + def new_init(self, *args, **kwargs): + # Initialize columns with default values + for col_name, col_def in cls.__columns__.items(): + setattr(self, col_name, col_def.default) + # Call original init + original_init(self, *args, **kwargs) + + cls.__init__ = new_init + return cls + + return decorator + + +def test(): + @table("models") + class Model(BaseEntity): + id: int = column(int, required=True, key=True) + path: str = column(str, required=True) + name: str = column(str, required=True) + description: Optional[str] = column(str) + architecture: Optional[str] = column(str) + type: str = column(str, required=True) + hash: Optional[str] = column(str) + source_url: Optional[str] = column(str) + + return Model + + +@table("test") +class Test(GetEntity, CreateEntity): + id: int = column(int, required=True, key=True) + test: str = column(str, required=True) + + +Model = test() diff --git a/app/database/routes.py b/app/database/routes.py new file mode 100644 index 00000000..eb2c14b9 --- /dev/null +++ b/app/database/routes.py @@ -0,0 +1,32 @@ +from app.database.db import db +from aiohttp import web + +def create_routes( + routes, prefix, entity, get=False, get_by_id=False, post=False, delete=False +): + if get: + @routes.get(f"/{prefix}/{table}") + async def get_table(request): + connection = db.get_connection() + cursor = connection.cursor() + cursor.execute(f"SELECT * FROM {table}") + rows = cursor.fetchall() + return web.json_response(rows) + + if get_by_id: + @routes.get(f"/{prefix}/{table}/{id}") + async def get_table_by_id(request): + connection = db.get_connection() + cursor = connection.cursor() + cursor.execute(f"SELECT * FROM {table} WHERE id = {id}") + row = cursor.fetchone() + return web.json_response(row) + + if post: + @routes.post(f"/{prefix}/{table}") + async def post_table(request): + data = await request.json() + connection = db.get_connection() + cursor = connection.cursor() + cursor.execute(f"INSERT INTO {table} ({data}) VALUES ({data})") + return web.json_response({"status": "success"}) diff --git a/app/database_updater.py b/app/database/updater.py similarity index 56% rename from app/database_updater.py rename to app/database/updater.py index 58acb9d6..2194a048 100644 --- a/app/database_updater.py +++ b/app/database/updater.py @@ -1,13 +1,16 @@ import logging import os +import sqlite3 +from app.database.versions.v1 import v1 class DatabaseUpdater: - def __init__(self, connection): + def __init__(self, connection, database_path): self.connection = connection + self.database_path = database_path self.current_version = self.get_db_version() self.version_updates = { - 1: self._update_to_v1, + 1: v1, } self.max_version = max(self.version_updates.keys()) self.update_required = self.current_version < self.max_version @@ -16,16 +19,35 @@ class DatabaseUpdater: def get_db_version(self): return self.connection.execute("PRAGMA user_version").fetchone()[0] + def backup(self): + bkp_path = self.database_path + ".bkp" + if os.path.exists(bkp_path): + # TODO: auto-rollback failed upgrades + raise Exception( + f"Database backup already exists, this indicates that a previous upgrade failed. Please restore this backup before continuing. Backup location: {bkp_path}" + ) + + bkp = sqlite3.connect(bkp_path) + self.connection.backup(bkp) + bkp.close() + logging.info("Database backup taken pre-upgrade.") + return bkp_path + def update(self): if not self.update_required: - return + return None + + bkp_version = self.current_version + bkp_path = None + if self.current_version > 0: + bkp_path = self.backup() logging.info(f"Updating database: {self.current_version} -> {self.max_version}") dirname = os.path.dirname(__file__) cursor = self.connection.cursor() for version in range(self.current_version + 1, self.max_version + 1): - filename = os.path.join(dirname, f"db/v{version}.sql") + filename = os.path.join(dirname, f"versions/v{version}.sql") if not os.path.exists(filename): raise Exception( f"Database update script for version {version} not found" @@ -49,6 +71,9 @@ class DatabaseUpdater: cursor.close() self.current_version = self.get_db_version() - def _update_to_v1(self, cursor): - # TODO: migrate users and settings - print("Updating to v1") + if bkp_path: + # Keep a copy of the backup in case something goes wrong and we need to rollback + os.rename(bkp_path, self.database_path + f".v{bkp_version}.bkp") + logging.info(f"Upgrade to successful.") + + return (bkp_version, self.current_version) diff --git a/app/database/versions/v1.py b/app/database/versions/v1.py new file mode 100644 index 00000000..ddafc2a9 --- /dev/null +++ b/app/database/versions/v1.py @@ -0,0 +1,17 @@ +from folder_paths import folder_names_and_paths, get_filename_list, get_full_path + + +def v1(cursor): + print("Updating to v1") + for folder_name in folder_names_and_paths.keys(): + if folder_name == "custom_nodes": + continue + + files = get_filename_list(folder_name) + for file in files: + file_path = get_full_path(folder_name, file) + file_without_extension = file.rsplit(".", maxsplit=1)[0] + cursor.execute( + "INSERT INTO models (path, name, type) VALUES (?, ?, ?)", + (file_path, file_without_extension, folder_name), + ) diff --git a/app/db/v1.sql b/app/database/versions/v1.sql similarity index 67% rename from app/db/v1.sql rename to app/database/versions/v1.sql index fd0783a5..fed71c67 100644 --- a/app/db/v1.sql +++ b/app/database/versions/v1.sql @@ -3,9 +3,10 @@ CREATE TABLE IF NOT EXISTS id INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT NOT NULL, name TEXT NOT NULL, - model TEXT NOT NULL, + description TEXT, + architecture TEXT, type TEXT NOT NULL, - hash TEXT NOT NULL, + hash TEXT, source_url TEXT ); @@ -23,3 +24,18 @@ CREATE TABLE IF NOT EXISTS FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE, FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE ); + +INSERT INTO + tags (name) +VALUES + ('character'), + ('style'), + ('concept'), + ('clothing'), + ('poses'), + ('background'), + ('vehicle'), + ('buildings'), + ('objects'), + ('animal'), + ('action'); diff --git a/app/model_hasher.py b/app/model_hasher.py new file mode 100644 index 00000000..f6a0e283 --- /dev/null +++ b/app/model_hasher.py @@ -0,0 +1,63 @@ +import hashlib +import logging +import threading +import time +from comfy.cli_args import args + + +class ModelHasher: + + def __init__(self): + self._thread = None + self._lock = threading.Lock() + self._model_entity = None + + def start(self, model_entity): + if args.disable_model_hashing: + return + + self._model_entity = model_entity + + if self._thread is None: + # Lock to prevent multiple threads from starting + with self._lock: + if self._thread is None: + self._thread = threading.Thread(target=self._hash_models) + self._thread.daemon = True + self._thread.start() + + def _get_models(self): + models = self._model_entity.get("WHERE hash IS NULL") + return models + + def _hash_model(self, model_path): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(model_path, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + hash = h.hexdigest() + return hash + + def _hash_models(self): + while True: + models = self._get_models() + + if len(models) == 0: + break + + for model in models: + time.sleep(0) + now = time.time() + logging.info(f"Hashing model {model['path']}") + hash = self._hash_model(model["path"]) + logging.info( + f"Hashed model {model['path']} in {time.time() - now} seconds" + ) + self._model_entity.update((model["id"],), {"hash": hash}) + + self._thread = None + + +model_hasher = ModelHasher() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 672e5d84..02fb3560 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -144,10 +144,12 @@ parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choic parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") parser.add_argument("--memory-database", default=False, action="store_true", help="Use an in-memory database instead of a file-based one.") +parser.add_argument("--disable-model-hashing", action="store_true", help="Disable model hashing.") # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" + parser.add_argument( "--front-end-version", type=str, diff --git a/server.py b/server.py index bae898ef..10be4b1d 100644 --- a/server.py +++ b/server.py @@ -34,6 +34,9 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional from api_server.routes.internal.internal_routes import InternalRoutes +from app.database.entities import get_entity, init_entities +from app.database.db import db +from app.model_hasher import model_hasher class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -682,11 +685,25 @@ class PromptServer(): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout) + def init_db(self, routes): + init_entities(routes) + models = get_entity("models") + + if db.exists: + model_hasher.start(models) + else: + def on_db_update(_, __): + model_hasher.start(models) + + db.register_after_update_callback(on_db_update) + + def add_routes(self): self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) + self.init_db(self.routes) # Prefix every route with /api for easier matching for delegation. # This is very useful for frontend dev server, which need to forward diff --git a/tests-unit/app_test/entities_test.py b/tests-unit/app_test/entities_test.py new file mode 100644 index 00000000..d0eb5945 --- /dev/null +++ b/tests-unit/app_test/entities_test.py @@ -0,0 +1,401 @@ +from comfy.cli_args import args + +args.memory_database = True # force in-memory database for testing + +from typing import Callable, Optional +import pytest +import pytest_asyncio +from unittest.mock import patch +from aiohttp import web +from app.database.entities import ( + column, + table, + Column, + GetEntity, + GetEntityById, + CreateEntity, + UpsertEntity, +) +from app.database.db import db + +pytestmark = pytest.mark.asyncio + + +def create_table(entity): + # reset db + db.close() + + cols: list[Column] = entity.__columns__ + # Create tables as temporary so when we close the db, the tables are dropped for next test + sql = f"CREATE TEMPORARY TABLE {entity.__table_name__} ( " + for col_name, col in cols.items(): + type = None + if col.type == int: + type = "INTEGER" + elif col.type == str: + type = "TEXT" + + sql += f"{col_name} {type}" + if col.required: + sql += " NOT NULL" + sql += ", " + + sql += f"PRIMARY KEY ({', '.join(entity.__key_columns__)})" + sql += ")" + db.execute(sql) + + +async def wrap_db(method: Callable, expected_sql: str, expected_args: list): + with patch.object(db, "execute", wraps=db.execute) as mock: + response = await method() + assert mock.call_args[0][0] == expected_sql + assert mock.call_args[0][1:] == expected_args + return response + + +@pytest.fixture +def getable_entity(): + @table("getable_entity") + class GetableEntity(GetEntity): + id: int = column(int, required=True, key=True) + test: str = column(str, required=True) + nullable: Optional[str] = column(str) + + return GetableEntity + + +@pytest.fixture +def getable_by_id_entity(): + @table("getable_by_id_entity") + class GetableByIdEntity(GetEntityById): + id: int = column(int, required=True, key=True) + test: str = column(str, required=True) + + return GetableByIdEntity + + +@pytest.fixture +def getable_by_id_composite_entity(): + @table("getable_by_id_composite_entity") + class GetableByIdCompositeEntity(GetEntityById): + id1: str = column(str, required=True, key=True) + id2: int = column(int, required=True, key=True) + test: str = column(str, required=True) + + return GetableByIdCompositeEntity + + +@pytest.fixture +def creatable_entity(): + @table("creatable_entity") + class CreatableEntity(CreateEntity): + id: int = column(int, required=True, key=True) + test: str = column(str, required=True) + reqd: str = column(str, required=True) + nullable: Optional[str] = column(str) + + return CreatableEntity + + +@pytest.fixture +def upsertable_entity(): + @table("upsertable_entity") + class UpsertableEntity(UpsertEntity): + id: int = column(int, required=True, key=True) + test: str = column(str, required=True) + reqd: str = column(str, required=True) + nullable: Optional[str] = column(str) + + return UpsertableEntity + + +@pytest.fixture() +def entity(request): + value = request.getfixturevalue(request.param) + create_table(value) + return value + + +@pytest_asyncio.fixture +async def client(aiohttp_client, app): + return await aiohttp_client(app) + + +@pytest.fixture +def app(entity): + app = web.Application() + routes = web.RouteTableDef() + entity.register_route(routes) + app.add_routes(routes) + return app + + +@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True) +async def test_get_model_empty_response(client): + expected_sql = "SELECT * FROM getable_entity" + expected_args = () + response = await wrap_db( + lambda: client.get("/db/getable_entity"), expected_sql, expected_args + ) + + assert response.status == 200 + assert await response.json() == [] + + +@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True) +async def test_get_model_with_data(client): + # seed db + db.execute( + "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')" + ) + + expected_sql = "SELECT * FROM getable_entity" + expected_args = () + response = await wrap_db( + lambda: client.get("/db/getable_entity"), expected_sql, expected_args + ) + + assert response.status == 200 + assert await response.json() == [ + {"id": 1, "test": "test1", "nullable": None}, + {"id": 2, "test": "test2", "nullable": "test2"}, + ] + + +@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True) +async def test_get_model_with_top_parameter(client): + # seed with 3 rows + db.execute( + "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')" + ) + + expected_sql = "SELECT * FROM getable_entity LIMIT 2" + expected_args = () + response = await wrap_db( + lambda: client.get("/db/getable_entity?top=2"), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == [ + {"id": 1, "test": "test1", "nullable": None}, + {"id": 2, "test": "test2", "nullable": "test2"}, + ] + + +@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True) +async def test_get_model_with_invalid_top_parameter(client): + response = await client.get("/db/getable_entity?top=hello") + assert response.status == 400 + assert await response.json() == { + "message": "Invalid top parameter", + "field": "top", + "value": "hello", + } + + +@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True) +async def test_get_model_by_id_empty_response(client): + # seed db + db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')") + + expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?" + expected_args = (1,) + response = await wrap_db( + lambda: client.get("/db/getable_by_id_entity/1"), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == [ + {"id": 1, "test": "test1"}, + ] + + +@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True) +async def test_get_model_by_id_with_invalid_id(client): + response = await client.get("/db/getable_by_id_entity/hello") + assert response.status == 400 + assert await response.json() == { + "message": "Invalid value", + "field": "id", + "value": "hello", + } + + +@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True) +async def test_get_model_by_id_composite(client): + # seed db + db.execute( + "INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')" + ) + + expected_sql = ( + "SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?" + ) + expected_args = ("one", 2) + response = await wrap_db( + lambda: client.get("/db/getable_by_id_composite_entity/one/2"), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == [ + {"id1": "one", "id2": 2, "test": "test"}, + ] + + +@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True) +async def test_get_model_by_id_composite_with_invalid_id(client): + response = await client.get("/db/getable_by_id_composite_entity/hello/hello") + assert response.status == 400 + assert await response.json() == { + "message": "Invalid value", + "field": "id2", + "value": "hello", + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model(client): + expected_sql = ( + "INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *" + ) + expected_args = (1, "test1", "reqd1") + response = await wrap_db( + lambda: client.post( + "/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"} + ), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == { + "id": 1, + "test": "test1", + "reqd": "reqd1", + "nullable": None, + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_missing_required_field(client): + response = await client.post( + "/db/creatable_entity", json={"id": 1, "test": "test1"} + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Missing field", + "field": "reqd", + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_missing_key_field(client): + response = await client.post( + "/db/creatable_entity", + json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Missing field", + "field": "id", + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_invalid_key_data(client): + response = await client.post( + "/db/creatable_entity", + json={ + "id": "not_an_integer", + "test": "test1", + "reqd": "reqd1", + }, # id should be int + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Invalid value", + "field": "id", + "value": "not_an_integer", + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_invalid_field_data(client): + response = await client.post( + "/db/creatable_entity", + json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Invalid value", + "field": "id", + "value": "aaa", + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_invalid_field_type(client): + response = await client.post( + "/db/creatable_entity", + json={ + "id": 1, + "test": ["invalid_array"], + "reqd": "reqd1", + }, # test should be string + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Invalid value", + "field": "test", + "value": ["invalid_array"], + } + + +@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True) +async def test_create_model_invalid_field_name(client): + response = await client.post( + "/db/creatable_entity", + json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"}, + ) + + assert response.status == 400 + assert await response.json() == { + "message": "Unknown field", + "field": "nonexistent_field", + } + + +@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True) +async def test_upsert_model(client): + expected_sql = ( + "INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) " + "ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd " + "RETURNING *" + ) + expected_args = (1, "test1", "reqd1") + response = await wrap_db( + lambda: client.put( + "/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"} + ), + expected_sql, + expected_args, + ) + + assert response.status == 200 + assert await response.json() == { + "id": 1, + "test": "test1", + "reqd": "reqd1", + "nullable": None, + }