Compare commits

82 Commits

| SHA1 |
|---|
| 9a616b81c1 |
| 3bed56bb13 |
| 4e402b11c6 |
| 48272448ad |
| f7695b5f9e |
| 452179fe4f |
| bf9a90a145 |
| c1b92b719d |
| cdc3b97dd5 |
| 8d4e06324f |
| 57e8bf6a9f |
| 0ee322ec5f |
| 79d5ceae6e |
| 2d5b3e0078 |
| 8e4118c0de |
| 3fc6ebcdd7 |
| 20a560eb97 |
| 82c5308561 |
| 26fb2c68e8 |
| bf2650a80e |
| 53646e0f32 |
| 20879c78f9 |
| b666539595 |
| 95d8713482 |
| 0d4e29f13f |
| 497db6212f |
| 24dc581dc3 |
| 4c82741b54 |
| 15c39ea757 |
| b7143b74ce |
| 61196d8857 |
| b4526d3fc3 |
| 3d802710e7 |
| 7126ecffde |
| ab885b33ba |
| 839ed3368e |
| 6e8cdcd3cb |
| e5c3f4b87f |
| bc6be6c11e |
| 94323a26a7 |
| 5818f6cf51 |
| 0b734de449 |
| 5e16f1d24b |
| 2fd9c1308a |
| 8f0009aad0 |
| 41444b5236 |
| 772e620e32 |
| 07f6eeaa13 |
| 22535d0589 |
| 898615122f |
| 156a28786b |
| f498d855ba |
| b699a15062 |
| 9cc90ee3eb |
| 9a0a5d32ee |
| d9f90965c8 |
| 41886af138 |
| 22a1d7ce78 |
| 4ac401af2b |
| 5fb59c8475 |
| 122c9ca1ce |
| 3b9a6cf2b1 |
| 3748e7ef7a |
| 8ebf2d8831 |
| a72d152b0c |
| eb476e6ea9 |
| 2d28b0b479 |
| 8b275ce5be |
| 2a18e98ccf |
| 8a5281006f |
| bdeb1c171c |
| 9c1ed58ef2 |
| 8b90e50979 |
| 6ee066a14f |
| dd5b57e3d7 |
| 75a818c720 |
| 2865f913f7 |
| b49616f951 |
| 5e29e7a488 |
| 8afb97cd3f |
| 69694f40b3 |
| c49025f01b |
**.github/workflows/test-launch.yml** (vendored, 2 changes)
```diff
@@ -28,7 +28,7 @@ jobs:
       - name: Start ComfyUI server
         run: |
           python main.py --cpu 2>&1 | tee console_output.log &
-          wait-for-it --service 127.0.0.1:8188 -t 600
+          wait-for-it --service 127.0.0.1:8188 -t 30
         working-directory: ComfyUI
       - name: Check for unhandled exceptions in server log
         run: |
```
**README.md** (71 changes)
````diff
@@ -28,7 +28,7 @@
 [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
 [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
 
-![ComfyUI Screenshot](https://github.com/comfyanonymous/ComfyUI/blob/master/comfyui_screenshot.png)
+![ComfyUI Screenshot](https://media.githubusercontent.com/media/comfyanonymous/ComfyUI/refs/heads/master/comfyui_screenshot.png)
 </div>
 
 This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
@@ -39,6 +39,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
 - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
 - Asynchronous Queue system
@@ -74,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 
 | Keybind | Explanation |
 |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt` + `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
 
-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users
 
 # Installing
 
@@ -140,7 +141,7 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
 
 This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
 
@@ -212,6 +213,14 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
 
 For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
 
+### AMD ROCm Tips
+
+You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+
+```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
+
+You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
+
 # Notes
 
 Only parts of the graph that have an output with all the correct inputs will be executed.
````
```diff
@@ -2,6 +2,7 @@ from aiohttp import web
 from typing import Optional
 from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
 from api_server.services.file_service import FileService
+from api_server.services.terminal_service import TerminalService
 import app.logger
 
 class InternalRoutes:
@@ -9,9 +10,9 @@ class InternalRoutes:
     The top level web router for internal routes: /internal/*
     The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
     Check README.md for more information.
 
     '''
-    def __init__(self):
+    def __init__(self, prompt_server):
         self.routes: web.RouteTableDef = web.RouteTableDef()
         self._app: Optional[web.Application] = None
         self.file_service = FileService({
@@ -19,6 +20,8 @@ class InternalRoutes:
             "user": user_directory,
             "output": output_directory
         })
+        self.prompt_server = prompt_server
+        self.terminal_service = TerminalService(prompt_server)
 
     def setup_routes(self):
         @self.routes.get('/files')
@@ -34,7 +37,28 @@ class InternalRoutes:
 
         @self.routes.get('/logs')
         async def get_logs(request):
-            return web.json_response(app.logger.get_logs())
+            return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
+
+        @self.routes.get('/logs/raw')
+        async def get_logs(request):
+            self.terminal_service.update_size()
+            return web.json_response({
+                "entries": list(app.logger.get_logs()),
+                "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
+            })
+
+        @self.routes.patch('/logs/subscribe')
+        async def subscribe_logs(request):
+            json_data = await request.json()
+            client_id = json_data["clientId"]
+            enabled = json_data["enabled"]
+            if enabled:
+                self.terminal_service.subscribe(client_id)
+            else:
+                self.terminal_service.unsubscribe(client_id)
+
+            return web.Response(status=200)
+
 
         @self.routes.get('/folder_paths')
         async def get_folder_paths(request):
```
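For orientation, here is a minimal client sketch for the new subscription endpoint. It is not part of the diff: the host/port and the `"my-client-id"` value are assumptions, and the client id must correspond to an open websocket session registered with the server for log pushes to arrive.

```python
# Hypothetical usage sketch for PATCH /internal/logs/subscribe (stdlib only).
# Host, port and client id are placeholders, not values from the diff.
import json
import urllib.request

payload = json.dumps({"clientId": "my-client-id", "enabled": True}).encode()
req = urllib.request.Request(
    "http://127.0.0.1:8188/internal/logs/subscribe",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="PATCH",
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)  # 200 once the subscription is toggled
```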
**api_server/services/terminal_service.py** (new file, 60 lines)
```python
from app.logger import on_flush
import os
import shutil


class TerminalService:
    def __init__(self, server):
        self.server = server
        self.cols = None
        self.rows = None
        self.subscriptions = set()
        on_flush(self.send_messages)

    def get_terminal_size(self):
        try:
            size = os.get_terminal_size()
            return (size.columns, size.lines)
        except OSError:
            try:
                size = shutil.get_terminal_size()
                return (size.columns, size.lines)
            except OSError:
                return (80, 24)  # fallback to 80x24

    def update_size(self):
        columns, lines = self.get_terminal_size()
        changed = False

        if columns != self.cols:
            self.cols = columns
            changed = True

        if lines != self.rows:
            self.rows = lines
            changed = True

        if changed:
            return {"cols": self.cols, "rows": self.rows}

        return None

    def subscribe(self, client_id):
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        if not len(entries) or not len(self.subscriptions):
            return

        new_size = self.update_size()

        for client_id in self.subscriptions.copy():  # prevent: Set changed size during iteration
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected
                self.unsubscribe(client_id)
                continue

            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
```
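To illustrate the flow: the constructor registers `send_messages` as a flush callback, so every flush of the intercepted stdout/stderr pushes the buffered entries to subscribed clients, and clients whose sockets have disappeared are dropped automatically. A minimal sketch with a stand-in server object (`FakeServer` is hypothetical; it only mimics the two members the service touches):

```python
# Sketch: exercise TerminalService against a fake PromptServer-like object.
from api_server.services.terminal_service import TerminalService

class FakeServer:
    sockets = {"client-a": object()}  # pretend one websocket is connected

    def send_sync(self, event, data, sid):
        print(event, sid, [e["m"] for e in data["entries"]])

service = TerminalService(FakeServer())
service.subscribe("client-a")
# app.logger invokes this through the on_flush callback on each flush:
service.send_messages([{"t": "2024-01-01T00:00:00", "m": "hello\n"}])
```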
```diff
@@ -1,20 +1,69 @@
-import logging
-from logging.handlers import MemoryHandler
 from collections import deque
+from datetime import datetime
+import io
+import logging
+import sys
+import threading
 
 logs = None
-formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+stdout_interceptor = None
+stderr_interceptor = None
+
+
+class LogInterceptor(io.TextIOWrapper):
+    def __init__(self, stream, *args, **kwargs):
+        buffer = stream.buffer
+        encoding = stream.encoding
+        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
+        self._lock = threading.Lock()
+        self._flush_callbacks = []
+        self._logs_since_flush = []
+
+    def write(self, data):
+        entry = {"t": datetime.now().isoformat(), "m": data}
+        with self._lock:
+            self._logs_since_flush.append(entry)
+
+            # Simple handling for cr to overwrite the last output if it isnt a full line
+            # else logs just get full of progress messages
+            if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
+                logs.pop()
+            logs.append(entry)
+        super().write(data)
+
+    def flush(self):
+        super().flush()
+        for cb in self._flush_callbacks:
+            cb(self._logs_since_flush)
+        self._logs_since_flush = []
+
+    def on_flush(self, callback):
+        self._flush_callbacks.append(callback)
+
 
 def get_logs():
-    return "\n".join([formatter.format(x) for x in logs])
+    return logs
+
+
+def on_flush(callback):
+    if stdout_interceptor is not None:
+        stdout_interceptor.on_flush(callback)
+    if stderr_interceptor is not None:
+        stderr_interceptor.on_flush(callback)
 
 def setup_logger(log_level: str = 'INFO', capacity: int = 300):
     global logs
     if logs:
         return
 
+    # Override output streams and log to buffer
+    logs = deque(maxlen=capacity)
+
+    global stdout_interceptor
+    global stderr_interceptor
+    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
+    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
+
     # Setup default global logger
     logger = logging.getLogger()
     logger.setLevel(log_level)
@@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter("%(message)s"))
     logger.addHandler(stream_handler)
-
-    # Create a memory handler with a deque as its buffer
-    logs = deque(maxlen=capacity)
-    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
-    memory_handler.buffer = logs
-    memory_handler.setFormatter(formatter)
-    logger.addHandler(memory_handler)
```
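The net effect of the rewrite, assuming the module layout above: `setup_logger()` swaps `sys.stdout`/`sys.stderr` for `LogInterceptor` instances, so anything written to either stream lands both on the real terminal and in the ring buffer that `get_logs()` now returns directly.

```python
# Sketch of the interceptor behavior. Entries are dicts with an ISO
# timestamp ("t") and the raw written text ("m").
from app.logger import setup_logger, get_logs

setup_logger(log_level="INFO", capacity=300)
print("hello world")  # reaches the real stdout *and* the ring buffer
assert any("hello world" in entry["m"] for entry in get_logs())
```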
```diff
@@ -1,25 +1,42 @@
+from __future__ import annotations
 import json
 import os
 import re
 import uuid
 import glob
 import shutil
+import logging
 from aiohttp import web
 from urllib import parse
 from comfy.cli_args import args
 import folder_paths
 from .app_settings import AppSettings
+from typing import TypedDict
 
 default_user = "default"
 
+
+class FileInfo(TypedDict):
+    path: str
+    size: int
+    modified: int
+
+
+def get_file_info(path: str, relative_to: str) -> FileInfo:
+    return {
+        "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
+        "size": os.path.getsize(path),
+        "modified": os.path.getmtime(path)
+    }
+
+
 class UserManager():
     def __init__(self):
         user_directory = folder_paths.get_user_directory()
 
         self.settings = AppSettings(self)
         if not os.path.exists(user_directory):
-            os.mkdir(user_directory)
+            os.makedirs(user_directory, exist_ok=True)
             if not args.multi_user:
                 print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                 print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
@@ -154,6 +171,7 @@ class UserManager():
 
         recurse = request.rel_url.query.get('recurse', '').lower() == "true"
         full_info = request.rel_url.query.get('full_info', '').lower() == "true"
+        split_path = request.rel_url.query.get('split', '').lower() == "true"
 
         # Use different patterns based on whether we're recursing or not
         if recurse:
@@ -161,26 +179,21 @@ class UserManager():
         else:
             pattern = os.path.join(glob.escape(path), '*')
 
-        results = glob.glob(pattern, recursive=recurse)
+        def process_full_path(full_path: str) -> FileInfo | str | list[str]:
+            if full_info:
+                return get_file_info(full_path, path)
 
-        if full_info:
-            results = [
-                {
-                    'path': os.path.relpath(x, path).replace(os.sep, '/'),
-                    'size': os.path.getsize(x),
-                    'modified': os.path.getmtime(x)
-                } for x in results if os.path.isfile(x)
-            ]
-        else:
-            results = [
-                os.path.relpath(x, path).replace(os.sep, '/')
-                for x in results
-                if os.path.isfile(x)
-            ]
+            rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
+            if split_path:
+                return [rel_path] + rel_path.split('/')
 
-        split_path = request.rel_url.query.get('split', '').lower() == "true"
-        if split_path and not full_info:
-            results = [[x] + x.split('/') for x in results]
+            return rel_path
+
+        results = [
+            process_full_path(full_path)
+            for full_path in glob.glob(pattern, recursive=recurse)
+            if os.path.isfile(full_path)
+        ]
 
         return web.json_response(results)
@@ -208,20 +221,51 @@ class UserManager():
 
         @routes.post("/userdata/{file}")
         async def post_userdata(request):
+            """
+            Upload or update a user data file.
+
+            This endpoint handles file uploads to a user's data directory, with options for
+            controlling overwrite behavior and response format.
+
+            Query Parameters:
+            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
+            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
+                                    If "false", returns only the relative file path.
+
+            Path Parameters:
+            - file: The target file path (URL encoded if necessary).
+
+            Returns:
+            - 400: If 'file' parameter is missing.
+            - 403: If the requested path is not allowed.
+            - 409: If overwrite=false and the file already exists.
+            - 200: JSON response with either:
+                   - Full file information (if full_info=true)
+                   - Relative file path (if full_info=false)
+
+            The request body should contain the raw file content to be written.
+            """
             path = get_user_data_path(request)
             if not isinstance(path, str):
                 return path
 
-            overwrite = request.query["overwrite"] != "false"
+            overwrite = request.query.get("overwrite", 'true') != "false"
+            full_info = request.query.get('full_info', 'false').lower() == "true"
+
             if not overwrite and os.path.exists(path):
-                return web.Response(status=409)
+                return web.Response(status=409, text="File already exists")
 
             body = await request.read()
 
             with open(path, "wb") as f:
                 f.write(body)
 
-            resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
+            user_path = self.get_request_user_filepath(request, None)
+            if full_info:
+                resp = get_file_info(path, user_path)
+            else:
+                resp = os.path.relpath(path, user_path)
+
             return web.json_response(resp)
 
         @routes.delete("/userdata/{file}")
@@ -236,6 +280,30 @@ class UserManager():
 
         @routes.post("/userdata/{file}/move/{dest}")
         async def move_userdata(request):
+            """
+            Move or rename a user data file.
+
+            This endpoint handles moving or renaming files within a user's data directory, with options for
+            controlling overwrite behavior and response format.
+
+            Path Parameters:
+            - file: The source file path (URL encoded if necessary)
+            - dest: The destination file path (URL encoded if necessary)
+
+            Query Parameters:
+            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
+            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
+                                    If "false", returns only the relative file path.
+
+            Returns:
+            - 400: If either 'file' or 'dest' parameter is missing
+            - 403: If either requested path is not allowed
+            - 404: If the source file does not exist
+            - 409: If overwrite=false and the destination file already exists
+            - 200: JSON response with either:
+                   - Full file information (if full_info=true)
+                   - Relative file path (if full_info=false)
+            """
             source = get_user_data_path(request, check_exists=True)
             if not isinstance(source, str):
                 return source
@@ -244,12 +312,19 @@ class UserManager():
             if not isinstance(source, str):
                 return dest
 
-            overwrite = request.query["overwrite"] != "false"
-            if not overwrite and os.path.exists(dest):
-                return web.Response(status=409)
+            overwrite = request.query.get("overwrite", 'true') != "false"
+            full_info = request.query.get('full_info', 'false').lower() == "true"
+
+            if not overwrite and os.path.exists(dest):
+                return web.Response(status=409, text="File already exists")
 
-            print(f"moving '{source}' -> '{dest}'")
+            logging.info(f"moving '{source}' -> '{dest}'")
             shutil.move(source, dest)
 
-            resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
+            user_path = self.get_request_user_filepath(request, None)
+            if full_info:
+                resp = get_file_info(dest, user_path)
+            else:
+                resp = os.path.relpath(dest, user_path)
+
             return web.json_response(resp)
```
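Both endpoints now build their `full_info` responses through the shared `get_file_info` helper. A small sketch of its output shape, assuming the module lives at `app/user_manager.py` (the file header was not captured in this view) and that the placeholder paths exist when run:

```python
# Sketch: FileInfo paths are relative and use forward slashes on every OS.
import os
from app.user_manager import get_file_info

root = "/tmp/comfy_user"  # hypothetical user directory
target = os.path.join(root, "workflows", "demo.json")  # hypothetical file
info = get_file_info(target, root)
print(info)  # {'path': 'workflows/demo.json', 'size': ..., 'modified': ...}
```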
**comfy/cldm/dit_embedder.py** (new file, 122 lines)
```python
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> Tuple[Tensor, List[Tensor]]:
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        block_out = ()

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
```
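The `forward` pass fans a small stack of controlnet blocks out over the main model's larger stack of double blocks via the `repeat` factor. A worked example of that arithmetic with illustrative block counts (not values from the diff):

```python
# Illustrative only: a 6-layer embedder driving 24 main-model double blocks
# repeats each controlnet output ceil(24 / 6) = 4 times.
import math

main_model_double = 24  # double blocks in the main MMDiT (assumed)
num_layers = 6          # transformer blocks in the embedder (assumed)
repeat = math.ceil(main_model_double / num_layers)

block_out = ()
for i in range(num_layers):
    block_out += (f"controlnet_block_{i}",) * repeat

print(repeat, len(block_out))  # 4 24 -> one entry per main-model block
```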
```diff
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpunet_group = parser.add_mutually_exclusive_group()
-fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
-fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
```
```diff
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
 
 ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
                "gelu": torch.nn.functional.gelu,
+               "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
 }
 
 class CLIPMLP(torch.nn.Module):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
 
 
 class CLIPVisionEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
         super().__init__()
-        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        num_patches = (image_size // patch_size) ** 2
+        if model_type == "siglip_vision_model":
+            self.class_embedding = None
+            patch_bias = True
+        else:
+            num_patches = num_patches + 1
+            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+            patch_bias = False
 
         self.patch_embedding = operations.Conv2d(
             in_channels=num_channels,
             out_channels=embed_dim,
             kernel_size=patch_size,
             stride=patch_size,
-            bias=False,
+            bias=patch_bias,
             dtype=dtype,
             device=device
         )
 
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
-        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
 
     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+        if self.class_embedding is not None:
+            embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
+        return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
 
 
 class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
+        model_type = config_dict["model_type"]
 
-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
-        self.pre_layrnorm = operations.LayerNorm(embed_dim)
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
+        if model_type == "siglip_vision_model":
+            self.pre_layrnorm = lambda a: a
+            self.output_layernorm = True
+        else:
+            self.pre_layrnorm = operations.LayerNorm(embed_dim)
+            self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
 
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
         x = self.pre_layrnorm(x)
         #TODO: attention_mask?
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
-        pooled_output = self.post_layernorm(x[:, 0, :])
+        if self.output_layernorm:
+            x = self.post_layernorm(x)
+            pooled_output = x
+        else:
+            pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
-        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        if "projection_dim" in config_dict:
+            self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        else:
+            self.visual_projection = lambda a: a
 
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
```
```diff
@@ -16,13 +16,18 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224):
-    mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
-    std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
-        scale = (size / min(image.shape[2], image.shape[3]))
-        image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+        if crop:
+            scale = (size / min(image.shape[2], image.shape[3]))
+            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+        else:
+            scale_size = (size, size)
+
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -35,6 +40,8 @@ class ClipVisionModel():
             config = json.load(f)
 
         self.image_size = config.get("image_size", 224)
+        self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
+        self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -49,9 +56,9 @@ class ClipVisionModel():
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image):
+    def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
         outputs = Output()
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+        if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+        elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
```
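A hedged usage sketch of the extended preprocessing, assuming the module path `comfy/clip_vision.py`: with `crop=False` the image is resized straight to `size`×`size` instead of scaled and center-cropped, and the SigLIP-style mean/std of 0.5 can be passed through.

```python
# Sketch: non-square input, SigLIP normalization, no center crop.
# Input follows the [batch, height, width, channels] convention in 0..1.
import torch
from comfy.clip_vision import clip_preprocess

image = torch.rand(1, 512, 768, 3)
pixels = clip_preprocess(image, size=384,
                         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5],
                         crop=False)
print(pixels.shape)  # expected: torch.Size([1, 3, 384, 384])
```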
**comfy/clip_vision_siglip_384.json** (new file, 13 lines)
```json
{
    "num_channels": 3,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5]
}
```
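Worked numbers for this config: SigLIP has no class token, so the position-embedding count equals the raw patch count, which is exactly the `+1` branch that `CLIPVisionEmbeddings` above skips for `siglip_vision_model`.

```python
# Position-embedding count for the SigLIP-384 config above.
image_size, patch_size = 384, 14
num_patches = (image_size // patch_size) ** 2
print(num_patches)      # 729 (27 x 27), no [CLS] slot
print(num_patches + 1)  # 730 would be the CLIP-style count with [CLS]
```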
**comfy/comfy_types/README.md** (new file, 43 lines)
````markdown
# Comfy Typing
## Type hinting for ComfyUI Node development

This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:

```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin

class ExampleNode(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {"required": {}}
```

Full example is in [examples/example_nodes.py](examples/example_nodes.py).

# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.

## `IO`

A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:

- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`

## `ComfyNodeABC`

An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.

### Type hinting for `INPUT_TYPES`

![INPUT_TYPES](examples/input_types.png)

### `INPUT_TYPES` return dict

![INPUT_TYPES return dict](examples/input_options.png)

### Options for individual inputs

![Options for individual inputs](examples/required_hint.png)
````
```diff
@@ -1,5 +1,6 @@
 import torch
 from typing import Callable, Protocol, TypedDict, Optional, List
+from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
 
 
 class UnetApplyFunction(Protocol):
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
 
 
 UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
+
+
+__all__ = [
+    "UnetWrapperFunction",
+    UnetApplyConds.__name__,
+    UnetParams.__name__,
+    UnetApplyFunction.__name__,
+    IO.__name__,
+    InputTypeDict.__name__,
+    ComfyNodeABC.__name__,
+    CheckLazyMixin.__name__,
+]
```
**comfy/comfy_types/examples/example_nodes.py** (new file, 28 lines)
```python
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc


class ExampleNode(ComfyNodeABC):
    """An example node that just adds 1 to an input integer.

    * Requires an IDE configured with analysis paths etc to be worth looking at.
    * Not intended for use in ComfyUI.
    """

    DESCRIPTION = cleandoc(__doc__)
    CATEGORY = "examples"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "input_int": (IO.INT, {"defaultInput": True}),
            }
        }

    RETURN_TYPES = (IO.INT,)
    RETURN_NAMES = ("input_plus_one",)
    FUNCTION = "execute"

    def execute(self, input_int: int):
        return (input_int + 1,)
```
**comfy/comfy_types/examples/input_options.png** (new binary file, 19 KiB, not shown)

**comfy/comfy_types/examples/input_types.png** (new binary file, 16 KiB, not shown)

**comfy/comfy_types/examples/required_hint.png** (new binary file, 19 KiB, not shown)
**comfy/comfy_types/node_typing.py** (new file, 274 lines)
```python
"""Comfy-specific type hinting"""

from __future__ import annotations
from typing import Literal, TypedDict
from abc import ABC, abstractmethod
from enum import Enum


class StrEnum(str, Enum):
    """Base class for string enums. Python's StrEnum is not available until 3.11."""

    def __str__(self) -> str:
        return self.value


class IO(StrEnum):
    """Node input/output data types.

    Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
    """

    STRING = "STRING"
    IMAGE = "IMAGE"
    MASK = "MASK"
    LATENT = "LATENT"
    BOOLEAN = "BOOLEAN"
    INT = "INT"
    FLOAT = "FLOAT"
    CONDITIONING = "CONDITIONING"
    SAMPLER = "SAMPLER"
    SIGMAS = "SIGMAS"
    GUIDER = "GUIDER"
    NOISE = "NOISE"
    CLIP = "CLIP"
    CONTROL_NET = "CONTROL_NET"
    VAE = "VAE"
    MODEL = "MODEL"
    CLIP_VISION = "CLIP_VISION"
    CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
    STYLE_MODEL = "STYLE_MODEL"
    GLIGEN = "GLIGEN"
    UPSCALE_MODEL = "UPSCALE_MODEL"
    AUDIO = "AUDIO"
    WEBCAM = "WEBCAM"
    POINT = "POINT"
    FACE_ANALYSIS = "FACE_ANALYSIS"
    BBOX = "BBOX"
    SEGS = "SEGS"

    ANY = "*"
    """Always matches any type, but at a price.

    Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
    """
    NUMBER = "FLOAT,INT"
    """A float or an int - could be either"""
    PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
    """Could be any of: string, float, int, or bool"""

    def __ne__(self, value: object) -> bool:
        if self == "*" or value == "*":
            return False
        if not isinstance(value, str):
            return True
        a = frozenset(self.split(","))
        b = frozenset(value.split(","))
        return not (b.issubset(a) or a.issubset(b))


class InputTypeOptions(TypedDict):
    """Provides type hinting for the return type of the INPUT_TYPES node function.

    Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
    """

    default: bool | str | float | int | list | tuple
    """The default value of the widget"""
    defaultInput: bool
    """Defaults to an input slot rather than a widget"""
    forceInput: bool
    """`defaultInput` and also don't allow converting to a widget"""
    lazy: bool
    """Declares that this input uses lazy evaluation"""
    rawLink: bool
    """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
    tooltip: str
    """Tooltip for the input (or widget), shown on pointer hover"""
    # class InputTypeNumber(InputTypeOptions):
    # default: float | int
    min: float
    """The minimum value of a number (``FLOAT`` | ``INT``)"""
    max: float
    """The maximum value of a number (``FLOAT`` | ``INT``)"""
    step: float
    """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
    round: float
    """Floats are rounded by this value (``FLOAT``)"""
    # class InputTypeBoolean(InputTypeOptions):
    # default: bool
    label_on: str
    """The label to use in the UI when the bool is True (``BOOLEAN``)"""
    label_off: str
    """The label to use in the UI when the bool is False (``BOOLEAN``)"""
    # class InputTypeString(InputTypeOptions):
    # default: str
    multiline: bool
    """Use a multiline text box (``STRING``)"""
    placeholder: str
    """Placeholder text to display in the UI when empty (``STRING``)"""
    # Deprecated:
    # defaultVal: str
    dynamicPrompts: bool
    """Causes the front-end to evaluate dynamic prompts (``STRING``)"""


class HiddenInputTypeDict(TypedDict):
    """Provides type hinting for the hidden entry of node INPUT_TYPES."""

    node_id: Literal["UNIQUE_ID"]
    """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
    unique_id: Literal["UNIQUE_ID"]
    """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
    prompt: Literal["PROMPT"]
    """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
    extra_pnginfo: Literal["EXTRA_PNGINFO"]
    """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
    dynprompt: Literal["DYNPROMPT"]
    """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""


class InputTypeDict(TypedDict):
    """Provides type hinting for node INPUT_TYPES.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
    """

    required: dict[str, tuple[IO, InputTypeOptions]]
    """Describes all inputs that must be connected for the node to execute."""
    optional: dict[str, tuple[IO, InputTypeOptions]]
    """Describes inputs which do not need to be connected."""
    hidden: HiddenInputTypeDict
    """Offers advanced functionality and server-client communication.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
    """


class ComfyNodeABC(ABC):
    """Abstract base class for Comfy nodes. Includes the names and expected types of attributes.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
    """

    DESCRIPTION: str
    """Node description, shown as a tooltip when hovering over the node.

    Usage::

        # Explicitly define the description
        DESCRIPTION = "Example description here."

        # Use the docstring of the node class.
        DESCRIPTION = cleandoc(__doc__)
    """
    CATEGORY: str
    """The category of the node, as per the "Add Node" menu.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
    """
    EXPERIMENTAL: bool
    """Flags a node as experimental, informing users that it may change or not work as expected."""
    DEPRECATED: bool
    """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""

    @classmethod
    @abstractmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        """Defines node inputs.

        * Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
        * The ``optional`` key can be added to describe inputs which do not need to be connected.
        * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs

        Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
        """
        return {"required": {}}

    OUTPUT_NODE: bool
    """Flags this node as an output node, causing any inputs it requires to be executed.

    If a node is not connected to any output nodes, that node will not be executed. Usage::

        OUTPUT_NODE = True

    From the docs:

    By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
    """
    INPUT_IS_LIST: bool
    """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.

    All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.

    From the docs:

    A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
    """
    OUTPUT_IS_LIST: tuple[bool]
    """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.

    Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.

    A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::

        RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
        OUTPUT_IS_LIST = (True, True, False)  # The string output will be handled normally

    From the docs:

    In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
    the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
    specifying which outputs which should be so treated.

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
    """

    RETURN_TYPES: tuple[IO]
    """A tuple representing the outputs of this node.

    Usage::

        RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
    """
    RETURN_NAMES: tuple[str]
    """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
    """
    OUTPUT_TOOLTIPS: tuple[str]
    """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
    FUNCTION: str
    """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`

    Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
    """


class CheckLazyMixin:
    """Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""

    def check_lazy_status(self, **kwargs) -> list[str]:
        """Returns a list of input names that should be evaluated.

        This basic mixin impl. requires all inputs.

        :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
            When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.

        Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
        Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).

        Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
        """

        need = [name for name in kwargs if kwargs[name] is None]
        return need
```
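The overridden `__ne__` is what gives `IO` its wildcard and multi-type matching; a short demonstration of the comparison rules it encodes:

```python
# Demonstrating the IO comparison semantics defined above.
from comfy.comfy_types.node_typing import IO

print(IO.ANY != "IMAGE")      # False: "*" never compares unequal
print(IO.NUMBER != "INT")     # False: {"INT"} is a subset of {"FLOAT", "INT"}
print(IO.NUMBER != "STRING")  # True: no subset relationship either way
```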
```diff
@@ -35,6 +35,10 @@ import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
+import comfy.cldm.dit_embedder
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from comfy.hooks import HookGroup
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -78,6 +82,8 @@ class ControlBase:
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
+        self.extra_hooks: HookGroup = None
+        self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
@@ -114,6 +120,14 @@ class ControlBase:
         if self.previous_controlnet is not None:
             out += self.previous_controlnet.get_models()
         return out
 
+    def get_extra_hooks(self):
+        out = []
+        if self.extra_hooks is not None:
+            out.append(self.extra_hooks)
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_extra_hooks()
+        return out
+
     def copy_to(self, c):
         c.cond_hint_original = self.cond_hint_original
@@ -129,6 +143,8 @@ class ControlBase:
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig = self.extra_concat_orig.copy()
+        c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
+        c.preprocess_image = self.preprocess_image
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -181,7 +197,7 @@ class ControlBase:
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
         super().__init__()
         self.control_model = control_model
         self.load_device = load_device
@@ -196,11 +212,12 @@ class ControlNet(ControlBase):
         self.extra_conds += extra_conds
         self.strength_type = strength_type
         self.concat_mask = concat_mask
+        self.preprocess_image = preprocess_image
 
-    def get_control(self, x_noisy, t, cond, batched_number):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
 
         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -224,6 +241,7 @@ class ControlNet(ControlBase):
             if self.latent_format is not None:
                 raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = self.preprocess_image(self.cond_hint)
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                 self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +445,7 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model
 
+
 def load_controlnet_mmdit(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +467,82 @@ def load_controlnet_mmdit(sd, model_options={}):
     return control
 
 
+class ControlNetSD35(ControlNet):
+    def pre_run(self, model, percent_to_timestep_function):
+        if self.control_model.double_y_emb:
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+        else:
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+        super().pre_run(model, percent_to_timestep_function)
+
+    def copy(self):
+        c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+
+def load_controlnet_sd35(sd, model_options={}):
+    control_type = -1
+    if "control_type" in sd:
+        control_type = round(sd.pop("control_type").item())
+
+    # blur_cnet = control_type == 0
+    canny_cnet = control_type == 1
+    depth_cnet = control_type == 2
+
+    new_sd = {}
+    for k in comfy.utils.MMDIT_MAP_BASIC:
+        if k[1] in sd:
+            new_sd[k[0]] = sd.pop(k[1])
+    for k in sd:
+        new_sd[k] = sd[k]
+    sd = new_sd
+
+    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
+    depth = y_emb_shape[0] // 64
+    hidden_size = 64 * depth
+    num_heads = depth
+    head_dim = hidden_size // num_heads
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
+
+    load_device = comfy.model_management.get_torch_device()
+    offload_device = comfy.model_management.unet_offload_device()
+    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
+                                                               patch_size=2,
+                                                               in_chans=16,
+                                                               num_layers=num_blocks,
+                                                               main_model_double=depth,
+                                                               double_y_emb=y_emb_shape[0] == y_emb_shape[1],
+                                                               attention_head_dim=head_dim,
+                                                               num_attention_heads=num_heads,
+                                                               adm_in_channels=2048,
+                                                               device=offload_device,
+                                                               dtype=unet_dtype,
+                                                               operations=operations)
```
|
||||
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
preprocess_image = lambda a: a
|
||||
if canny_cnet:
|
||||
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
||||
elif depth_cnet:
|
||||
preprocess_image = lambda a: 1.0 - a
|
||||
|
||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
return control
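As a sketch, assuming a hypothetical cond-hint tensor `hint` in [0, 1], the two `preprocess_image` lambdas above amount to the following remapping before the hint is VAE-encoded:

import torch
hint = torch.rand(1, 3, 512, 512)      # hypothetical cond hint image
canny_in = hint * 255 * 0.5 + 0.5      # canny variant: rescale the hint
depth_in = 1.0 - hint                  # depth variant: invert the map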


def load_controlnet_hunyuandit(controlnet_data, model_options={}):
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

@@ -560,7 +655,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
    if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
        return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
    elif "pos_embed_input.proj.weight" in controlnet_data:
        return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
            return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
        else:
            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
    elif "controlnet_x_embedder.weight" in controlnet_data:
        return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
    elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
@@ -674,10 +772,10 @@ class T2IAdapter(ControlBase):
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:

690 comfy/hooks.py Normal file
@@ -0,0 +1,690 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools

if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher, PatcherInjection
    from comfy.model_base import BaseModel
    from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values

class EnumHookMode(enum.Enum):
    MinVram = "minvram"
    MaxSpeed = "maxspeed"

class EnumHookType(enum.Enum):
    Weight = "weight"
    Patch = "patch"
    ObjectPatch = "object_patch"
    AddModels = "add_models"
    Callbacks = "callbacks"
    Wrappers = "wrappers"
    SetInjections = "add_injections"

class EnumWeightTarget(enum.Enum):
    Model = "model"
    Clip = "clip"

class _HookRef:
    pass

# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
    return True


class Hook:
    def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
                 hook_keyframe: 'HookKeyframeGroup'=None):
        self.hook_type = hook_type
        self.hook_ref = hook_ref if hook_ref else _HookRef()
        self.hook_id = hook_id
        self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
        self.custom_should_register = default_should_register
        self.auto_apply_to_nonpositive = False

    @property
    def strength(self):
        return self.hook_keyframe.strength

    def initialize_timesteps(self, model: 'BaseModel'):
        self.reset()
        self.hook_keyframe.initialize_timesteps(model)

    def reset(self):
        self.hook_keyframe.reset()

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: Hook = subtype()
        c.hook_type = self.hook_type
        c.hook_ref = self.hook_ref
        c.hook_id = self.hook_id
        c.hook_keyframe = self.hook_keyframe
        c.custom_should_register = self.custom_should_register
        # TODO: make this do something
        c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
        return c

    def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        return self.custom_should_register(self, model, model_options, target, registered)

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

    def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
        pass

    def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
        pass

    def __eq__(self, other: 'Hook'):
        return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

    def __hash__(self):
        return hash(self.hook_ref)

class WeightHook(Hook):
    def __init__(self, strength_model=1.0, strength_clip=1.0):
        super().__init__(hook_type=EnumHookType.Weight)
        self.weights: dict = None
        self.weights_clip: dict = None
        self.need_weight_init = True
        self._strength_model = strength_model
        self._strength_clip = strength_clip

    @property
    def strength_model(self):
        return self._strength_model * self.strength

    @property
    def strength_clip(self):
        return self._strength_clip * self.strength

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        if not self.should_register(model, model_options, target, registered):
            return False
        weights = None
        if target == EnumWeightTarget.Model:
            strength = self._strength_model
        else:
            strength = self._strength_clip

        if self.need_weight_init:
            key_map = {}
            if target == EnumWeightTarget.Model:
                key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
            else:
                key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
            weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
        else:
            if target == EnumWeightTarget.Model:
                weights = self.weights
            else:
                weights = self.weights_clip
        k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
        registered.append(self)
        return True
        # TODO: add logs about any keys that were not applied

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: WeightHook = super().clone(subtype)
        c.weights = self.weights
        c.weights_clip = self.weights_clip
        c.need_weight_init = self.need_weight_init
        c._strength_model = self._strength_model
        c._strength_clip = self._strength_clip
        return c

class PatchHook(Hook):
    def __init__(self):
        super().__init__(hook_type=EnumHookType.Patch)
        self.patches: dict = None

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: PatchHook = super().clone(subtype)
        c.patches = self.patches
        return c
    # TODO: add functionality

class ObjectPatchHook(Hook):
    def __init__(self):
        super().__init__(hook_type=EnumHookType.ObjectPatch)
        self.object_patches: dict = None

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: ObjectPatchHook = super().clone(subtype)
        c.object_patches = self.object_patches
        return c
    # TODO: add functionality

class AddModelsHook(Hook):
    def __init__(self, key: str=None, models: list['ModelPatcher']=None):
        super().__init__(hook_type=EnumHookType.AddModels)
        self.key = key
        self.models = models
        self.append_when_same = True

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: AddModelsHook = super().clone(subtype)
        c.key = self.key
        c.models = self.models.copy() if self.models else self.models
        c.append_when_same = self.append_when_same
        return c
    # TODO: add functionality

class CallbackHook(Hook):
    def __init__(self, key: str=None, callback: Callable=None):
        super().__init__(hook_type=EnumHookType.Callbacks)
        self.key = key
        self.callback = callback

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: CallbackHook = super().clone(subtype)
        c.key = self.key
        c.callback = self.callback
        return c
    # TODO: add functionality

class WrapperHook(Hook):
    def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
        super().__init__(hook_type=EnumHookType.Wrappers)
        self.wrappers_dict = wrappers_dict

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: WrapperHook = super().clone(subtype)
        c.wrappers_dict = self.wrappers_dict
        return c

    def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
        if not self.should_register(model, model_options, target, registered):
            return False
        add_model_options = {"transformer_options": self.wrappers_dict}
        comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
        registered.append(self)
        return True

class SetInjectionsHook(Hook):
    def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
        super().__init__(hook_type=EnumHookType.SetInjections)
        self.key = key
        self.injections = injections

    def clone(self, subtype: Callable=None):
        if subtype is None:
            subtype = type(self)
        c: SetInjectionsHook = super().clone(subtype)
        c.key = self.key
        c.injections = self.injections.copy() if self.injections else self.injections
        return c

    def add_hook_injections(self, model: 'ModelPatcher'):
        # TODO: add functionality
        pass

class HookGroup:
    def __init__(self):
        self.hooks: list[Hook] = []

    def add(self, hook: Hook):
        if hook not in self.hooks:
            self.hooks.append(hook)

    def contains(self, hook: Hook):
        return hook in self.hooks

    def clone(self):
        c = HookGroup()
        for hook in self.hooks:
            c.add(hook.clone())
        return c

    def clone_and_combine(self, other: 'HookGroup'):
        c = self.clone()
        if other is not None:
            for hook in other.hooks:
                c.add(hook.clone())
        return c

    def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
        if hook_kf is None:
            hook_kf = HookKeyframeGroup()
        else:
            hook_kf = hook_kf.clone()
        for hook in self.hooks:
            hook.hook_keyframe = hook_kf

    def get_dict_repr(self):
        d: dict[EnumHookType, dict[Hook, None]] = {}
        for hook in self.hooks:
            with_type = d.setdefault(hook.hook_type, {})
            with_type[hook] = None
        return d

    def get_hooks_for_clip_schedule(self):
        scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
        for hook in self.hooks:
            # only care about WeightHooks, for now
            if hook.hook_type == EnumHookType.Weight:
                hook_schedule = []
                # if no hook keyframes, assign default value
                if len(hook.hook_keyframe.keyframes) == 0:
                    hook_schedule.append(((0.0, 1.0), None))
                    scheduled_hooks[hook] = hook_schedule
                    continue
                # find ranges of values
                prev_keyframe = hook.hook_keyframe.keyframes[0]
                for keyframe in hook.hook_keyframe.keyframes:
                    if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
                        hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
                        prev_keyframe = keyframe
                    elif keyframe.start_percent == prev_keyframe.start_percent:
                        prev_keyframe = keyframe
                # create final range, assuming last start_percent was not 1.0
                if not math.isclose(prev_keyframe.start_percent, 1.0):
                    hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
                scheduled_hooks[hook] = hook_schedule
        # hooks now have their schedules as a list of tuples
        all_ranges: list[tuple[float, float]] = []
        for range_kfs in scheduled_hooks.values():
            for t_range, keyframe in range_kfs:
                all_ranges.append(t_range)
        # turn list of ranges into boundaries
        boundaries_set = set(itertools.chain.from_iterable(all_ranges))
        boundaries_set.add(0.0)
        boundaries = sorted(boundaries_set)
        real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
        # with real ranges defined, give appropriate hooks w/ keyframes for each range
        scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
        for t_range in real_ranges:
            hooks_schedule = []
            for hook, val in scheduled_hooks.items():
                keyframe = None
                # check if there is a keyframe that applies to the current t_range
                for stored_range, stored_kf in val:
                    # if the stored range overlaps the current range, use its keyframe
                    if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
                        keyframe = stored_kf
                        break
                hooks_schedule.append((hook, keyframe))
            scheduled_keyframes.append((t_range, hooks_schedule))
        return scheduled_keyframes

    def reset(self):
        for hook in self.hooks:
            hook.reset()

    @staticmethod
    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
        actual: list[HookGroup] = []
        for group in hooks_list:
            if group is not None:
                actual.append(group)
        if len(actual) < require_count:
            raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
        # if no hooks, then return None
        if len(actual) == 0:
            return None
        # if only 1 hook, just return itself without cloning
        elif len(actual) == 1:
            return actual[0]
        final_hook: HookGroup = None
        for hook in actual:
            if final_hook is None:
                final_hook = hook.clone()
            else:
                final_hook = final_hook.clone_and_combine(hook)
        return final_hook


class HookKeyframe:
    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
        self.strength = strength
        # scheduling
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps

    def clone(self):
        c = HookKeyframe(strength=self.strength,
                         start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
        c.start_t = self.start_t
        return c

class HookKeyframeGroup:
    def __init__(self):
        self.keyframes: list[HookKeyframe] = []
        self._current_keyframe: HookKeyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        self._curr_t = -1.

    # properties shadow those of HookKeyframe
    @property
    def strength(self):
        if self._current_keyframe is not None:
            return self._current_keyframe.strength
        return 1.0

    def reset(self):
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
        self._curr_t = -1.
        self._set_first_as_current()

    def add(self, keyframe: HookKeyframe):
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()

    def _set_first_as_current(self):
        if len(self.keyframes) > 0:
            self._current_keyframe = self.keyframes[0]
        else:
            self._current_keyframe = None

    def has_index(self, index: int):
        return index >= 0 and index < len(self.keyframes)

    def is_empty(self):
        return len(self.keyframes) == 0

    def clone(self):
        c = HookKeyframeGroup()
        for keyframe in self.keyframes:
            c.keyframes.append(keyframe.clone())
        c._set_first_as_current()
        return c

    def initialize_timesteps(self, model: 'BaseModel'):
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, curr_t: float) -> bool:
        if self.is_empty():
            return False
        if curr_t == self._curr_t:
            return False
        prev_index = self._current_index
        prev_strength = self._current_strength
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.keyframes)):
                    eval_c = self.keyframes[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_strength = eval_c.strength
                        self._current_keyframe = eval_c
                        self._current_used_steps = 0
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_keyframe.guarantee_steps > 0:
                            break
                    # if eval_c is outside the percent range, stop looking further
                    else: break
        # update steps current context is used
        self._current_used_steps += 1
        # update current timestep this was performed on
        self._curr_t = curr_t
        # return True if keyframe changed, False if no change
        return prev_index != self._current_index and prev_strength != self._current_strength
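A minimal usage sketch, assuming only the classes above: schedule a hook at full strength for the first half of sampling, then at half strength.

# sketch: strength 1.0 until 50% of sampling, then 0.5
group = HookKeyframeGroup()
group.add(HookKeyframe(strength=1.0, start_percent=0.0))
group.add(HookKeyframe(strength=0.5, start_percent=0.5))
# initialize_timesteps(model) later converts each start_percent into a sigma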


class InterpolationMethod:
    LINEAR = "linear"
    EASE_IN = "ease_in"
    EASE_OUT = "ease_out"
    EASE_IN_OUT = "ease_in_out"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]

    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        diff = num_to - num_from
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method == cls.EASE_IN:
            index = torch.linspace(0, 1, length)
            weights = diff * np.power(index, 2) + num_from
        elif method == cls.EASE_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * (1 - np.power(1 - index, 2)) + num_from
        elif method == cls.EASE_IN_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights
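A quick sketch of `get_weights`; the values follow directly from the quadratic ease-in formula above.

w = InterpolationMethod.get_weights(num_from=1.0, num_to=0.0, length=4,
                                    method=InterpolationMethod.EASE_IN)
# ramps roughly as 1.0, 0.889, 0.556, 0.0 (diff * index**2 + num_from with diff = -1)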

def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    if not objects:
        return objects
    elif len(objects) <= 1:
        return [x for x in objects]
    # now that we know we have to sort, do it following these rules:
    # a) if objects have same value of attribute, maintain their relative order
    # b) perform sorting of the groups of objects with same attributes
    unique_attrs = {}
    for o in objects:
        val_attr = getattr(o, attr)
        attr_list: list = unique_attrs.get(val_attr, list())
        attr_list.append(o)
        if val_attr not in unique_attrs:
            unique_attrs[val_attr] = attr_list
    # now that we have the unique attr values grouped together in relative order, sort them by key
    sorted_attrs = dict(sorted(unique_attrs.items()))
    # now flatten out the dict into a list to return
    sorted_list = []
    for object_list in sorted_attrs.values():
        sorted_list.extend(object_list)
    return sorted_list
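A sketch of the stable grouped sort, using throwaway objects:

from types import SimpleNamespace
items = [SimpleNamespace(start_percent=0.5, tag="a"),
         SimpleNamespace(start_percent=0.0, tag="b"),
         SimpleNamespace(start_percent=0.5, tag="c")]
ordered = get_sorted_list_via_attr(items, "start_percent")
# tags come out ["b", "a", "c"]: equal keys keep their relative order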

def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
    hook_group = HookGroup()
    hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    hook_group.add(hook)
    hook.weights = lora
    return hook_group
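A usage sketch (the lora path below is hypothetical):

import comfy.utils
lora_sd = comfy.utils.load_torch_file("loras/example.safetensors", safe_load=True)  # hypothetical path
hooks = create_hook_lora(lora_sd, strength_model=1.0, strength_clip=1.0)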

def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
    hook_group = HookGroup()
    hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
    hook_group.add(hook)
    patches_model = None
    patches_clip = None
    if weights_model is not None:
        patches_model = {}
        for key in weights_model:
            patches_model[key] = ("model_as_lora", (weights_model[key],))
    if weights_clip is not None:
        patches_clip = {}
        for key in weights_clip:
            patches_clip[key] = ("model_as_lora", (weights_clip[key],))
    hook.weights = patches_model
    hook.weights_clip = patches_clip
    hook.need_weight_init = False
    return hook_group

def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
    if model is None:
        return None
    patches_model: dict[str, torch.Tensor] = model.model.state_dict()
    if discard_model_sampling:
        # do not include ANY model_sampling components of the model that should act as a patch
        for key in list(patches_model.keys()):
            if key.startswith("model_sampling"):
                patches_model.pop(key, None)
    return patches_model

# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
                              strength_model: float, strength_clip: float):
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    hook_group = HookGroup()
    hook = WeightHook()
    hook_group.add(hook)
    loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        new_modelpatcher = model.clone()
        k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
    else:
        k = ()
        new_modelpatcher = None

    if clip is not None:
        new_clip = clip.clone()
        k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
    else:
        k1 = ()
        new_clip = None
    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            print(f"NOT LOADED {x}")
    return (new_modelpatcher, new_clip, hook_group)

def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
    hooks_key = 'hooks'
    # if hooks only exist in one dict, do what's needed so that it ends up in c_dict
    if hooks_key not in values:
        return
    if hooks_key not in c_dict:
        hooks_value = values.get(hooks_key, None)
        if hooks_value is not None:
            c_dict[hooks_key] = hooks_value
        return
    # otherwise, need to combine with minimum duplication via cache
    hooks_tuple = (c_dict[hooks_key], values[hooks_key])
    cached_hooks = cache.get(hooks_tuple, None)
    if cached_hooks is None:
        new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
        cache[hooks_tuple] = new_hooks
        c_dict[hooks_key] = new_hooks
    else:
        c_dict[hooks_key] = cache[hooks_tuple]

def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
    c = []
    hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
    for t in conditioning:
        n = [t[0], t[1].copy()]
        for k in values:
            if append_hooks and k == 'hooks':
                _combine_hooks_from_values(n[1], values, hooks_combine_cache)
            else:
                n[1][k] = values[k]
        c.append(n)

    return c

def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
    if hooks is None:
        return cond
    return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)

def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
    if timestep_range is None:
        return cond
    return conditioning_set_values(cond, {"start_percent": timestep_range[0],
                                          "end_percent": timestep_range[1]})

def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
    if mask is None:
        return cond
    set_area_to_bounds = False
    if set_cond_area != 'default':
        set_area_to_bounds = True
    if len(mask.shape) < 3:
        mask = mask.unsqueeze(0)
    return conditioning_set_values(cond, {'mask': mask,
                                          'set_area_to_bounds': set_area_to_bounds,
                                          'mask_strength': strength})

def combine_conditioning(conds: list):
    combined_conds = []
    for cond in conds:
        combined_conds.extend(cond)
    return combined_conds

def combine_with_new_conds(conds: list, new_conds: list):
    combined_conds = []
    for c, new_c in zip(conds, new_conds):
        combined_conds.append(combine_conditioning([c, new_c]))
    return combined_conds

def set_conds_props(conds: list, strength: float, set_cond_area: str,
                    mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    final_conds = []
    for c in conds:
        # first, apply lora_hook to conditioning, if provided
        c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
        # next, apply mask to conditioning
        c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
        # apply timesteps, if present
        c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
        # finally, store the prepared conditioning
        final_conds.append(c)
    return final_conds

def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                                mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    combined_conds = []
    for c, masked_c in zip(conds, new_conds):
        # first, apply lora_hook to new conditioning, if provided
        masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
        # next, apply mask to new conditioning, if provided
        masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
        # apply timesteps, if present
        masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
        # finally, combine with existing conditioning and store
        combined_conds.append(combine_conditioning([c, masked_c]))
    return combined_conds
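A sketch of the intended call shape; `positive`, `masked_positive`, `mask` and `hooks` are hypothetical inputs:

combined = set_conds_props_and_combine([positive], [masked_positive],
                                       strength=1.0, set_cond_area="default",
                                       mask=mask, hooks=hooks,
                                       timesteps_range=(0.0, 0.5))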

def set_default_conds_and_combine(conds: list, new_conds: list,
                                  hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
    combined_conds = []
    for c, new_c in zip(conds, new_conds):
        # first, apply lora_hook to new conditioning, if provided
        new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
        # next, add default_cond key to cond so that during sampling, it can be identified
        new_c = conditioning_set_values(new_c, {'default': True})
        # apply timesteps, if present
        new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
        # finally, combine with existing conditioning and store
        combined_conds.append(combine_conditioning([c, new_c]))
    return combined_conds
@@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

        if sigma_down == 0:
            x = denoised
        else:
            d = to_d(x, sigmas[i], denoised)
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x

@torch.no_grad()
@@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        # Euler method
        sigma_down_i_ratio = sigma_down / sigmas[i]
        x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
        if sigmas[i + 1] > 0 and eta > 0:
            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
            sigma_down = sigmas[i + 1] * downstep_ratio
            alpha_ip1 = 1 - sigmas[i + 1]
            alpha_down = 1 - sigma_down
            renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
            # Euler method
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
            if eta > 0:
                x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
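A numeric sanity check (a sketch with sample values) of the rectified-flow ancestral quantities above: by construction, the renoised variance matches sigma_{i+1}.

import math
sigma_i, sigma_ip1, eta = 0.8, 0.6, 1.0                      # sample values
downstep_ratio = 1 + (sigma_ip1 / sigma_i - 1) * eta
sigma_down = sigma_ip1 * downstep_ratio                      # 0.45
alpha_ip1, alpha_down = 1 - sigma_ip1, 1 - sigma_down
renoise_coeff = (sigma_ip1**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2) ** 0.5
assert math.isclose((alpha_ip1 / alpha_down * sigma_down) ** 2 + renoise_coeff ** 2,
                    sigma_ip1 ** 2)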

@torch.no_grad()
@@ -280,6 +285,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,

@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)

    """Ancestral sampling with DPM-Solver second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -306,6 +314,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x

@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x

def linear_multistep_coeff(order, t, i, j):
    if order - 1 > i:
@@ -190,7 +190,21 @@ class Mochi(LatentFormat):
                                         0.9294154431013696, 1.3720942357788521, 0.881393668867029,
                                         0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)

        self.latent_rgb_factors = None #TODO
        self.latent_rgb_factors = [
            [-0.0069, -0.0045, 0.0018],
            [ 0.0154, -0.0692, -0.0274],
            [ 0.0333, 0.0019, 0.0206],
            [-0.1390, 0.0628, 0.1678],
            [-0.0725, 0.0134, -0.1898],
            [ 0.0074, -0.0270, -0.0209],
            [-0.0176, -0.0277, -0.0221],
            [ 0.5294, 0.5204, 0.3852],
            [-0.0326, -0.0446, -0.0143],
            [-0.0659, 0.0153, -0.0153],
            [ 0.0185, -0.0217, 0.0014],
            [-0.0396, -0.0495, -0.0281]
        ]
        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
        self.taesd_decoder_name = None #TODO

    def process_in(self, latent):
@@ -202,3 +216,139 @@ class Mochi(LatentFormat):
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return latent * latents_std / self.scale_factor + latents_mean

class LTXV(LatentFormat):
    latent_channels = 128
    def __init__(self):
        self.latent_rgb_factors = [
            [ 1.1202e-02, -6.3815e-04, -1.0021e-02],
            [ 8.6031e-02, 6.5813e-02, 9.5409e-04],
            [-1.2576e-02, -7.5734e-03, -4.0528e-03],
            [ 9.4063e-03, -2.1688e-03, 2.6093e-03],
            [ 3.7636e-03, 1.2765e-02, 9.1548e-03],
            [ 2.1024e-02, -5.2973e-03, 3.4373e-03],
            [-8.8896e-03, -1.9703e-02, -1.8761e-02],
            [-1.3160e-02, -1.0523e-02, 1.9709e-03],
            [-1.5152e-03, -6.9891e-03, -7.5810e-03],
            [-1.7247e-03, 4.6560e-04, -3.3839e-03],
            [ 1.3617e-02, 4.7077e-03, -2.0045e-03],
            [ 1.0256e-02, 7.7318e-03, 1.3948e-02],
            [-1.6108e-02, -6.2151e-03, 1.1561e-03],
            [ 7.3407e-03, 1.5628e-02, 4.4865e-04],
            [ 9.5357e-04, -2.9518e-03, -1.4760e-02],
            [ 1.9143e-02, 1.0868e-02, 1.2264e-02],
            [ 4.4575e-03, 3.6682e-05, -6.8508e-03],
            [-4.5681e-04, 3.2570e-03, 7.7929e-03],
            [ 3.3902e-02, 3.3405e-02, 3.7454e-02],
            [-2.3001e-02, -2.4877e-03, -3.1033e-03],
            [ 5.0265e-02, 3.8841e-02, 3.3539e-02],
            [-4.1018e-03, -1.1095e-03, 1.5859e-03],
            [-1.2689e-01, -1.3107e-01, -2.1005e-01],
            [ 2.6276e-02, 1.4189e-02, -3.5963e-03],
            [-4.8679e-03, 8.8486e-03, 7.8029e-03],
            [-1.6610e-03, -4.8597e-03, -5.2060e-03],
            [-2.1010e-03, 2.3610e-03, 9.3796e-03],
            [-2.2482e-02, -2.1305e-02, -1.5087e-02],
            [-1.5753e-02, -1.0646e-02, -6.5083e-03],
            [-4.6975e-03, 5.0288e-03, -6.7390e-03],
            [ 1.1951e-02, 2.0712e-02, 1.6191e-02],
            [-6.3704e-03, -8.4827e-03, -9.5483e-03],
            [ 7.2610e-03, -9.9326e-03, -2.2978e-02],
            [-9.1904e-04, 6.2882e-03, 9.5720e-03],
            [-3.7178e-02, -3.7123e-02, -5.6713e-02],
            [-1.3373e-01, -1.0720e-01, -5.3801e-02],
            [-5.3702e-03, 8.1256e-03, 8.8397e-03],
            [-1.5247e-01, -2.1437e-01, -2.1843e-01],
            [ 3.1441e-02, 7.0335e-03, -9.7541e-03],
            [ 2.1528e-03, -8.9817e-03, -2.1023e-02],
            [ 3.8461e-03, -5.8957e-03, -1.5014e-02],
            [-4.3470e-03, -1.2940e-02, -1.5972e-02],
            [-5.4781e-03, -1.0842e-02, -3.0204e-03],
            [-6.5347e-03, 3.0806e-03, -1.0163e-02],
            [-5.0414e-03, -7.1503e-03, -8.9686e-04],
            [-8.5851e-03, -2.4351e-03, 1.0674e-03],
            [-9.0016e-03, -9.6493e-03, 1.5692e-03],
            [ 5.0914e-03, 1.2099e-02, 1.9968e-02],
            [ 1.3758e-02, 1.1669e-02, 8.1958e-03],
            [-1.0518e-02, -1.1575e-02, -4.1307e-03],
            [-2.8410e-02, -3.1266e-02, -2.2149e-02],
            [ 2.9336e-03, 3.6511e-02, 1.8717e-02],
            [-1.6703e-02, -1.6696e-02, -4.4529e-03],
            [ 4.8818e-02, 4.0063e-02, 8.7410e-03],
            [-1.5066e-02, -5.7328e-04, 2.9785e-03],
            [-1.7613e-02, -8.1034e-03, 1.3086e-02],
            [-9.2633e-03, 1.0803e-02, -6.3489e-03],
            [ 3.0851e-03, 4.7750e-04, 1.2347e-02],
            [-2.2785e-02, -2.3043e-02, -2.6005e-02],
            [-2.4787e-02, -1.5389e-02, -2.2104e-02],
            [-2.3572e-02, 1.0544e-03, 1.2361e-02],
            [-7.8915e-03, -1.2271e-03, -6.0968e-03],
            [-1.1478e-02, -1.2543e-03, 6.2679e-03],
            [-5.4229e-02, 2.6644e-02, 6.3394e-03],
            [ 4.4216e-03, -7.3338e-03, -1.0464e-02],
            [-4.5013e-03, 1.6082e-03, 1.4420e-02],
            [ 1.3673e-02, 8.8877e-03, 4.1253e-03],
            [-1.0145e-02, 9.0072e-03, 1.5695e-02],
            [-5.6234e-03, 1.1847e-03, 8.1261e-03],
            [-3.7171e-03, -5.3538e-03, 1.2590e-03],
            [ 2.9476e-02, 2.1424e-02, 3.0424e-02],
            [-3.4925e-02, -2.4340e-02, -2.5316e-02],
            [-3.4127e-02, -2.2406e-02, -1.0589e-02],
            [-1.7342e-02, -1.3249e-02, -1.0719e-02],
            [-2.1478e-03, -8.6051e-03, -2.9878e-03],
            [ 1.2089e-03, -4.2391e-03, -6.8569e-03],
            [ 9.0411e-04, -6.6886e-03, -6.7547e-05],
            [ 1.6048e-02, -1.0057e-02, -2.8929e-02],
            [ 1.2290e-03, 1.0163e-02, 1.8861e-02],
            [ 1.7264e-02, 2.7257e-04, 1.3785e-02],
            [-1.3482e-02, -3.6427e-03, 6.7481e-04],
            [ 4.6782e-03, -5.2423e-03, 2.4467e-03],
            [-5.9113e-03, -6.2244e-03, -1.8162e-03],
            [ 1.5496e-02, 1.4582e-02, 1.9514e-03],
            [ 7.4958e-03, 1.5886e-03, -8.2305e-03],
            [ 1.9086e-02, 1.6360e-03, -3.9674e-03],
            [-5.7021e-03, -2.7307e-03, -4.1066e-03],
            [ 1.7450e-03, 1.4602e-02, 2.5794e-02],
            [-8.2788e-04, 2.2902e-03, 4.5161e-03],
            [ 1.1632e-02, 8.9193e-03, -7.2813e-03],
            [ 7.5721e-03, 2.6784e-03, 1.1393e-02],
            [ 5.1939e-03, 3.6903e-03, 1.4049e-02],
            [-1.8383e-02, -2.2529e-02, -2.4477e-02],
            [ 5.8842e-04, -5.7874e-03, -1.4770e-02],
            [-1.6125e-02, -8.6101e-03, -1.4533e-02],
            [ 2.0540e-02, 2.0729e-02, 6.4338e-03],
            [ 3.3587e-03, -1.1226e-02, -1.6444e-02],
            [-1.4742e-03, -1.0489e-02, 1.7097e-03],
            [ 2.8130e-02, 2.3546e-02, 3.2791e-02],
            [-1.8532e-02, -1.2842e-02, -8.7756e-03],
            [-8.0533e-03, -1.0771e-02, -1.7536e-02],
            [-3.9009e-03, 1.6150e-02, 3.3359e-02],
            [-7.4554e-03, -1.4154e-02, -6.1910e-03],
            [ 3.4734e-03, -1.1370e-02, -1.0581e-02],
            [ 1.1476e-02, 3.9281e-03, 2.8231e-03],
            [ 7.1639e-03, -1.4741e-03, -3.8066e-03],
            [ 2.2250e-03, -8.7552e-03, -9.5719e-03],
            [ 2.4146e-02, 2.1696e-02, 2.8056e-02],
            [-5.4365e-03, -2.4291e-02, -1.7802e-02],
            [ 7.4263e-03, 1.0510e-02, 1.2705e-02],
            [ 6.2669e-03, 6.2658e-03, 1.9211e-02],
            [ 1.6378e-02, 9.4933e-03, 6.6971e-03],
            [ 1.7173e-02, 2.3601e-02, 2.3296e-02],
            [-1.4568e-02, -9.8279e-03, -1.1556e-02],
            [ 1.4431e-02, 1.4430e-02, 6.6362e-03],
            [-6.8230e-03, 1.8863e-02, 1.4555e-02],
            [ 6.1156e-03, 3.4700e-03, -2.6662e-03],
            [-2.6983e-03, -5.9402e-03, -9.2276e-03],
            [ 1.0235e-02, 7.4173e-03, -7.6243e-03],
            [-1.3255e-02, 1.9322e-02, -9.2153e-04],
            [ 2.4222e-03, -4.8039e-03, -1.5759e-02],
            [ 2.6244e-02, 2.5951e-02, 2.0249e-02],
            [ 1.5711e-02, 1.8498e-02, 2.7407e-03],
            [-2.1714e-03, 4.7214e-03, -2.2443e-02],
            [-7.4747e-03, 7.4166e-03, 1.4430e-02],
            [-8.3906e-03, -7.9776e-03, 9.7927e-03],
            [ 3.8321e-02, 9.6622e-03, -1.9268e-02],
            [-1.4605e-02, -6.7032e-03, 3.9675e-03]
        ]

        self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]

@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
        return_info = False,
        **kwargs
    ):
        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
        batch, seq, device = *x.shape[:2], x.device
        context = kwargs["context"]

        info = {
            "hidden_states": [],
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        blocks_replace = patches_replace.get("dit", {})
        # Iterate over the transformer layers
        for layer in self.layers:
            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
        for i, layer in enumerate(self.layers):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
            # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

        if return_info:
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
            mask=None,
            return_info=False,
            control=None,
            transformer_options={},
            **kwargs):
        return self._forward(
            x,

@@ -437,7 +437,8 @@ class MMDiT(nn.Module):
        pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, **kwargs):
    def forward(self, x, timestep, context, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})
        # patchify x, add PE
        b, c, h, w = x.shape

@@ -458,15 +459,36 @@ class MMDiT(nn.Module):

        global_cond = self.t_embedder(t, x.dtype)  # B, D

        blocks_replace = patches_replace.get("dit", {})
        if len(self.double_layers) > 0:
            for layer in self.double_layers:
                c, x = layer(c, x, global_cond, **kwargs)
            for i, layer in enumerate(self.double_layers):
                if ("double_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["txt"], out["img"] = layer(args["txt"],
                                                       args["img"],
                                                       args["vec"])
                        return out
                    out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
                    c = out["txt"]
                    x = out["img"]
                else:
                    c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for layer in self.single_layers:
                cx = layer(cx, global_cond, **kwargs)
            for i, layer in enumerate(self.single_layers):
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = layer(args["img"], args["vec"])
                        return out

                    out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
                    cx = out["img"]
                else:
                    cx = layer(cx, global_cond, **kwargs)

            x = cx[:, c_len:]

@@ -2,7 +2,7 @@ import torch
import comfy.ops

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]

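The parentheses added above matter because `A and B or C` parses as `(A and B) or C`; a sketch of the difference:

padding_mode, tracing, scripting = "reflect", False, True
old = padding_mode == "circular" and tracing or scripting        # True: scripting alone triggered the fallback
new = padding_mode == "circular" and (tracing or scripting)      # False: fallback now requires circular mode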
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
    depth_single_blocks: int
    axes_dim: list
    theta: int
    patch_size: int
    qkv_bias: bool
    guidance_embed: bool

@@ -43,8 +45,9 @@ class Flux(nn.Module):
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.in_channels = params.in_channels * 2 * 2
        self.out_channels = self.in_channels
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels * params.patch_size * params.patch_size
        self.out_channels = params.out_channels * params.patch_size * params.patch_size
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -96,7 +99,9 @@ class Flux(nn.Module):
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

@@ -114,8 +119,19 @@ class Flux(nn.Module):
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        if control is not None: # Controlnet
            control_i = control.get("input")
@@ -127,7 +143,16 @@ class Flux(nn.Module):
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            img = block(img, vec=vec, pe=pe)
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe)

        if control is not None: # Controlnet
            control_o = control.get("output")
@@ -141,9 +166,9 @@ class Flux(nn.Module):
        img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
        return img
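A sketch of how a `("double_block", i)` replacement is registered; the names are hypothetical, and the dict shape is what `transformer_options["patches_replace"]` is read as above:

def my_double_block(args, extra):
    out = extra["original_block"](args)   # run the stock block first
    out["img"] = out["img"] * 1.05        # then nudge the image stream
    return out

transformer_options = {"patches_replace": {"dit": {("double_block", 0): my_double_block}}}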

    def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        patch_size = self.patch_size
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
@@ -151,10 +176,10 @@ class Flux(nn.Module):
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
25 comfy/ldm/flux/redux.py Normal file
@@ -0,0 +1,25 @@
import torch
import comfy.ops

ops = comfy.ops.manual_cast

class ReduxImageEncoder(torch.nn.Module):
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype

        self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
        self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
        return projected_x
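A usage sketch with an assumed SigLIP embedding shape:

import torch
enc = ReduxImageEncoder(dtype=torch.float32)
sigclip_embeds = torch.randn(1, 729, 1152)   # (batch, patches, redux_dim); shape assumed
txt_tokens = enc(sigclip_embeds)             # projects to (1, 729, 4096)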
|
||||
@@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module):
         packed_indices: Dict[str, torch.Tensor] = None,
         rope_cos: torch.Tensor = None,
         rope_sin: torch.Tensor = None,
-        control=None, **kwargs
+        control=None, transformer_options={}, **kwargs
     ):
+        patches_replace = transformer_options.get("patches_replace", {})
         y_feat = context
         y_mask = attention_mask
         sigma = timestep
@@ -515,15 +516,32 @@
         )
         del y_mask

+        blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
-            x, y_feat = block(
-                x,
-                c,
-                y_feat,
-                rope_cos=rope_cos,
-                rope_sin=rope_sin,
-                crop_y=num_tokens,
-            )  # (B, M, D), (B, L, D)
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"], out["txt"] = block(
+                        args["img"],
+                        args["vec"],
+                        args["txt"],
+                        rope_cos=args["rope_cos"],
+                        rope_sin=args["rope_sin"],
+                        crop_y=args["num_tokens"]
+                    )
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
+                y_feat = out["txt"]
+                x = out["img"]
+            else:
+                x, y_feat = block(
+                    x,
+                    c,
+                    y_feat,
+                    rope_cos=rope_cos,
+                    rope_sin=rope_sin,
+                    crop_y=num_tokens,
+                )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.

         x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
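Each of these forwards now exposes the same hook: when transformer_options["patches_replace"]["dit"] contains a ("double_block", i) key, that entry is called with the block's tensors and an original_block callback instead of the block itself. A hedged sketch of what such an entry could look like (the patch body and wiring are illustrative, not part of this diff):

# Minimal sketch of a patches_replace entry matching the calling convention
# above: patch(args, extra) -> dict with the same keys it received.
def scaled_block_patch(args, extra):
    out = extra["original_block"](args)  # run the unmodified block first
    out["img"] = out["img"] * 0.9        # then edit its activations (illustrative)
    return out

# Assumed wiring; ComfyUI normally fills this via ModelPatcher.set_model_patch_replace.
transformer_options = {"patches_replace": {"dit": {("double_block", 0): scaled_block_patch}}}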
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
         style=None,
         return_dict=False,
         control=None,
-        transformer_options=None,
+        transformer_options={},
     ):
         """
         Forward pass of the encoder.
@@ -315,8 +315,7 @@
         return_dict: bool
             Whether to return a dictionary.
         """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
         encoder_hidden_states = context
         text_states = encoder_hidden_states           # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5     # 2,256,2048
@@ -364,6 +363,8 @@
         # Concatenate all extra vectors
         c = t + self.extra_embedder(extra_vec)  # [B, D]

+        blocks_replace = patches_replace.get("dit", {})
+
         controls = None
         if control:
             controls = control.get("output", None)
@@ -375,9 +376,20 @@
                 skip = skips.pop() + controls.pop().to(dtype=x.dtype)
             else:
                 skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
             else:
-                x = block(x, c, text_states, freqs_cis_img)  # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)

             if layer < (self.depth // 2 - 1):
                 skips.append(x)
comfy/ldm/lightricks/model.py (new file, 514 lines)
@@ -0,0 +1,514 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple

from .symmetric_patchifier import SymmetricPatchifier


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
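A small hedged example of what this function returns (the shapes and the zero-sine property follow directly from the formula above):

import torch

# get_timestep_embedding maps N scalar timesteps to N x embedding_dim vectors:
# the first half is sin at geometrically spaced frequencies, the second half cos.
t = torch.tensor([0.0, 500.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=256)
print(emb.shape)                  # torch.Size([3, 256])
print(emb[0, :128].abs().max())   # sin(0 * freq) == 0 for every frequency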
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        if post_act_fn is None:
            self.post_act = None
        # else:
        #     self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)


def apply_rotary_emb(input_tensor, freqs_cis):  # TODO: remove duplicate funcs and pick the best/fastest one
    cos_freqs = freqs_cis[0]
    sin_freqs = freqs_cis[1]

    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out
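This is the interleaved-pair form of RoPE: each consecutive feature pair (t1, t2) is rotated by the per-position angle carried in cos_freqs/sin_freqs. A hedged identity check at angle zero:

import torch

# Sanity sketch for apply_rotary_emb: with angle 0, cos=1 and sin=0,
# so the input should pass through unchanged.
x = torch.randn(1, 8, 16, 64)            # (batch, heads, tokens, head_dim)
cos = torch.ones_like(x)
sin = torch.zeros_like(x)
assert torch.allclose(apply_rotary_emb(x, (cos, sin)), x)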
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
        self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)

        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa

        x += self.attn2(x, context=context, mask=attention_mask)

        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x


def get_fractional_positions(indices_grid, max_pos):
    fractional_positions = torch.stack(
        [
            indices_grid[:, i] / max_pos[i]
            for i in range(3)
        ],
        dim=-1,
    )
    return fractional_positions


def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    dtype = torch.float32  # self.dtype

    fractional_positions = get_fractional_positions(indices_grid, max_pos)

    start = 1
    end = theta
    device = fractional_positions.device

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            dim // 6,
            device=device,
            dtype=dtype,
        )
    )
    indices = indices.to(dtype=dtype)

    indices = indices * math.pi / 2

    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )

    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
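Each of the three grid axes (frames, height, width) receives dim // 6 frequencies, which the sin/cos repeat_interleave expands to dim // 3 channels per axis; any remainder (dim % 6) is padded with the identity rotation cos=1, sin=0. A shape-only sketch with hypothetical sizes:

import torch

# Hypothetical shapes: batch of 2, 1024 flattened tokens, inner_dim 2048.
indices_grid = torch.zeros(2, 3, 1024)   # (batch, [frame, y, x], tokens)
cos_f, sin_f = precompute_freqs_cis(indices_grid, dim=2048, out_dtype=torch.float32)
print(cos_f.shape, sin_f.shape)  # torch.Size([2, 1024, 2048]) each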
class LTXVModel(torch.nn.Module):
    def __init__(self,
                 in_channels=128,
                 cross_attention_dim=2048,
                 attention_head_dim=64,
                 num_attention_heads=32,

                 caption_channels=4096,
                 num_layers=28,

                 positional_embedding_theta=10000.0,
                 positional_embedding_max_pos=[20, 2048, 2048],
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
        )

        # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    # attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for d in range(num_layers)
            ]
        )

        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)

        self.patchifier = SymmetricPatchifier(1)

    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=x.shape[2],
            orig_height=x.shape[3],
            orig_width=x.shape[4],
            batch_size=x.shape[0],
            scale_grid=((1 / frame_rate) * 8, 32, 32),
            device=x.device,
        )

        if guiding_latent is not None:
            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
            ts *= input_ts
            ts[:, :, 0] = 0.0
            timestep = self.patchifier.patchify(ts)
            input_x = x.clone()
            x[:, :, 0] = guiding_latent[:, :, 0]

        orig_shape = list(x.shape)

        x = self.patchifier.patchify(x)

        x = self.patchify_proj(x)
        timestep = timestep * 1000.0

        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)

        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(
                batch_size, -1, x.shape[-1]
            )

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        if guiding_latent is not None:
            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]

        # print("res", x)
        return x
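A hedged smoke test for the model above; all sizes are illustrative, and disable_weight_init leaves the weights uninitialized, so only the shapes are meaningful:

import torch
import comfy.ops

# LTXVModel consumes a 5-D video latent and returns one of the same shape.
model = LTXVModel(dtype=torch.float32, device="cpu", operations=comfy.ops.disable_weight_init)
x = torch.randn(1, 128, 9, 16, 24)     # (b, c, frames, h, w): 9*16*24 = 3456 tokens
t = torch.full((1,), 0.5)               # sigma-style timestep in [0, 1]
context = torch.randn(1, 77, 4096)      # T5-style caption embeddings
mask = torch.ones(1, 77)
out = model(x, t, context, mask)
print(out.shape)  # torch.Size([1, 128, 9, 16, 24])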
comfy/ldm/lightricks/symmetric_patchifier.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple

import torch
from einops import rearrange
from torch import Tensor


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    elif dims_to_append == 0:
        return x
    return x[(...,) + (None,) * dims_to_append]


class Patchifier(ABC):
    def __init__(self, patch_size: int):
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(
        self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
    ) -> Tuple[Tensor, Tensor]:
        pass

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_grid(
        self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
    ):
        f = orig_num_frames // self._patch_size[0]
        h = orig_height // self._patch_size[1]
        w = orig_width // self._patch_size[2]
        grid_h = torch.arange(h, dtype=torch.float32, device=device)
        grid_w = torch.arange(w, dtype=torch.float32, device=device)
        grid_f = torch.arange(f, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w)
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if scale_grid is not None:
            for i in range(3):
                if isinstance(scale_grid[i], Tensor):
                    scale = append_dims(scale_grid[i], grid.ndim - 1)
                else:
                    scale = scale_grid[i]
                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]

        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
        return grid


class SymmetricPatchifier(Patchifier):
    def patchify(
        self,
        latents: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=output_num_frames,
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
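SymmetricPatchifier is a pure reshape, so unpatchify inverts patchify exactly; a small hedged round-trip check:

import torch

# Round-trip sketch: patchify flattens (f, h, w) into tokens, unpatchify undoes it.
p = SymmetricPatchifier(patch_size=1)
x = torch.randn(2, 128, 5, 8, 8)               # (b, c, f, h, w)
tokens = p.patchify(x)
print(tokens.shape)                             # torch.Size([2, 320, 128])
y = p.unpatchify(tokens, output_height=8, output_width=8,
                 output_num_frames=5, out_channels=128)
assert torch.equal(x, y)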
comfy/ldm/lightricks/vae/causal_conv3d.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from typing import Tuple, Union

import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = ops.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        return self.conv.weight
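In causal mode the clip is front-padded by repeating the first frame time_kernel_size - 1 times, so output frame t only sees input frames at or before t, and the frame count is preserved at stride 1. A hedged shape check:

import torch

# Sketch: causal temporal padding keeps the frame count unchanged and
# makes the very first output frame depend on frame 0 alone.
conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3, stride=1)
x = torch.randn(1, 4, 8, 16, 16)         # (b, c, frames, h, w)
y = conv(x, causal=True)
print(y.shape)                            # torch.Size([1, 4, 8, 16, 16])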
comfy/ldm/lightricks/vae/causal_video_autoencoder.py (new file, 698 lines)
@@ -0,0 +1,698 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Any, Mapping, Optional, Tuple, Union, List
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks=[("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, output_channel, conv_out_channels, 3, padding=1, causal=True
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        return sample
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to use causal convolutions or not.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks=[("res_x", 1)],
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

        # Compute output channel to be product of all channel-multiplier blocks
        output_channel = base_channels
        for block_name, block_params in list(reversed(blocks)):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(blocks)):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(
                    dims=dims, in_channels=input_channel, stride=(2, 1, 1)
                )
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(
                    dims=dims, in_channels=input_channel, stride=(1, 2, 2)
                )
            elif block_name == "compress_all":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 2, 2),
                    residual=block_params.get("residual", False),
                )
            else:
                raise ValueError(f"unknown layer: {block_name}")

            self.up_blocks.append(block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims, output_channel, out_channels, 3, padding=1, causal=True
        )

        self.gradient_checkpointing = False

    # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        # assert target_shape is not None, "target_shape must be provided"

        sample = self.conv_in(sample, causal=self.causal)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self, hidden_states: torch.FloatTensor, causal: bool = True
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states, causal=causal)

        return hidden_states


class DepthToSpaceUpsample(nn.Module):
    def __init__(self, dims, in_channels, stride, residual=False):
        super().__init__()
        self.stride = stride
        self.out_channels = math.prod(stride) * in_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
        )
        self.residual = residual

    def forward(self, x, causal: bool = True):
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.stride[0],
                p2=self.stride[1],
                p3=self.stride[2],
            )
            x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x


class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
        )

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)

        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x
class processor(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(128))
        self.register_buffer("mean-of-means", torch.empty(128))
        self.register_buffer("mean-of-stds", torch.empty(128))
        self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
        self.register_buffer("channel", torch.empty(128))

    def un_normalize(self, x):
        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)

    def normalize(self, x):
        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)


class VideoVAE(nn.Module):
    def __init__(self):
        super().__init__()
        config = {
            "_class_name": "CausalVideoAutoencoder",
            "dims": 3,
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "blocks": [
                ["res_x", 4],
                ["compress_all", 1],
                ["res_x_y", 1],
                ["res_x", 3],
                ["compress_all", 1],
                ["res_x_y", 1],
                ["res_x", 3],
                ["compress_all", 1],
                ["res_x", 3],
                ["res_x", 4],
            ],
            "scaling_factor": 1.0,
            "norm_layer": "pixel_norm",
            "patch_size": 4,
            "latent_log_var": "uniform",
            "use_quant_conv": False,
            "causal_decoder": False,
        }

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )

        self.encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
        )

        self.decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
        )

        self.per_channel_statistics = processor()

    def encode(self, x):
        means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def decode(self, x):
        return self.decoder(self.per_channel_statistics.un_normalize(x))
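With a spatial patch size of 4 and three compress_all stages (each halving every dimension), this configuration compresses video 32x spatially and 8x temporally into 128 latent channels. A hedged shape sketch; the statistics buffers above are loaded from a checkpoint in practice, so the values here are meaningless and only the shape is informative:

import torch

# Illustrative shapes for the hard-coded VideoVAE config above.
vae = VideoVAE()
video = torch.randn(1, 3, 33, 256, 256)      # (b, rgb, 1 + 8*k frames, h, w)
latent = vae.encode(video)
print(latent.shape)                           # expected: torch.Size([1, 128, 5, 8, 8])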
comfy/ldm/lightricks/vae/conv_nd_factory.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from typing import Tuple, Union

import torch

from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
):
    if dims == 2:
        return ops.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        return ops.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
    elif dims == (2, 1):
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias=True,
):
    if dims == 2:
        return ops.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
        return ops.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
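The factory dispatches entirely on dims: 2 builds a plain Conv2d, 3 a Conv3d (or CausalConv3d when requested), and the tuple (2, 1) the spatial-then-temporal DualConv3d defined below. A hedged sketch of the three paths (channel sizes are illustrative):

# Sketch of the three factory paths.
conv_2d   = make_conv_nd(dims=2, in_channels=8, out_channels=16, kernel_size=3, padding=1)
conv_3d   = make_conv_nd(dims=3, in_channels=8, out_channels=16, kernel_size=3, padding=1, causal=True)
conv_2p1d = make_conv_nd(dims=(2, 1), in_channels=8, out_channels=16, kernel_size=3, padding=1)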
comfy/ldm/lightricks/vae/dual_conv3d.py (new file, 195 lines)
@@ -0,0 +1,195 @@
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if kernel_size == (1, 1, 1):
|
||||
raise ValueError(
|
||||
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
||||
)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
# Set parameters for convolutions
|
||||
self.groups = groups
|
||||
self.bias = bias
|
||||
|
||||
# Define the size of the channels after the first convolution
|
||||
intermediate_channels = (
|
||||
out_channels if in_channels < out_channels else in_channels
|
||||
)
|
||||
|
||||
# Define parameters for the first convolution
|
||||
self.weight1 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
intermediate_channels,
|
||||
in_channels // groups,
|
||||
1,
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
)
|
||||
)
|
||||
self.stride1 = (1, stride[1], stride[2])
|
||||
self.padding1 = (0, padding[1], padding[2])
|
||||
self.dilation1 = (1, dilation[1], dilation[2])
|
||||
if bias:
|
||||
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
||||
else:
|
||||
self.register_parameter("bias1", None)
|
||||
|
||||
# Define parameters for the second convolution
|
||||
self.weight2 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
||||
)
|
||||
)
|
||||
self.stride2 = (stride[0], 1, 1)
|
||||
self.padding2 = (padding[0], 0, 0)
|
||||
self.dilation2 = (dilation[0], 1, 1)
|
||||
if bias:
|
||||
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias2", None)
|
||||
|
||||
# Initialize weights and biases
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
||||
bound1 = 1 / math.sqrt(fan_in1)
|
||||
nn.init.uniform_(self.bias1, -bound1, bound1)
|
||||
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
||||
bound2 = 1 / math.sqrt(fan_in2)
|
||||
nn.init.uniform_(self.bias2, -bound2, bound2)
|
||||
|
||||
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
||||
if use_conv3d:
|
||||
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
||||
else:
|
||||
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
||||
|
||||
def forward_with_3d(self, x, skip_time_conv):
|
||||
# First convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight1,
|
||||
self.bias1,
|
||||
self.stride1,
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
self.stride2,
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# First 2D convolution
|
||||
x = rearrange(x, "b c d h w -> (b d) c h w")
|
||||
# Squeeze the depth dimension out of weight1 since it's 1
|
||||
weight1 = self.weight1.squeeze(2)
|
||||
# Select stride, padding, and dilation for the 2D convolution
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
|
||||
# Reshape weight2 to match the expected dimensions for conv1d
|
||||
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
||||
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
|
||||
# Initialize parameters
|
||||
in_channels = 3
|
||||
out_channels = 5
|
||||
kernel_size = (3, 3, 3)
|
||||
stride = (2, 2, 2)
|
||||
padding = (1, 1, 1)
|
||||
|
||||
# Create an instance of the DualConv3d class
|
||||
dual_conv3d = DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Example input tensor
|
||||
test_input = torch.randn(1, 3, 10, 10, 10)
|
||||
|
||||
# Perform forward passes with both 3D and 2D settings
|
||||
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
||||
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
||||
|
||||
# Assert that the outputs from both methods are sufficiently close
|
||||
assert torch.allclose(
|
||||
output_conv3d, output_2d, atol=1e-6
|
||||
), "Outputs are not consistent between 3D and 2D convolutions."
|
||||
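A brief usage sketch of the module above (illustrative, not part of the diff); the import path comes from the file header, the shapes are assumptions:

    import torch
    from comfy.ldm.lightricks.vae.dual_conv3d import DualConv3d

    # A 3x3x3 conv factorized into a (1,3,3) spatial conv plus a (3,1,1) temporal conv.
    conv = DualConv3d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
    video = torch.randn(2, 8, 6, 32, 32)        # (batch, channels, frames, height, width)
    out_2d_path = conv(video)                   # default path: conv2d over frames + conv1d over time
    out_3d_path = conv(video, use_conv3d=True)  # same weights via two F.conv3d calls
    assert out_2d_path.shape == out_3d_path.shape == (2, 16, 6, 32, 32)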
12 comfy/ldm/lightricks/vae/pixel_norm.py Normal file
@@ -0,0 +1,12 @@
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
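An illustrative check (not part of the diff) of what PixelNorm computes: it divides by the RMS over the chosen dimension, so the output has roughly unit RMS at every position:

    import torch
    from comfy.ldm.lightricks.vae.pixel_norm import PixelNorm

    x = torch.randn(4, 64, 8, 8)
    y = PixelNorm(dim=1)(x)
    rms = torch.sqrt(torch.mean(y ** 2, dim=1))  # RMS over the channel dim
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)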
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                if len(mask.shape) == 2:
                    s1 += mask[i:end]
                else:
                    s1 += mask[:, i:end]
                    if mask.shape[1] == 1:
                        s1 += mask
                    else:
                        s1 += mask[:, i:end]

            s2 = s1.softmax(dim=-1).to(v.dtype)
            del s1
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
        )

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]
        pad = 8 - mask.shape[-1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
        mask_out[..., :mask.shape[-1]] = mask
        mask = mask_out[..., :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
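A standalone sketch of the alignment trick in the new mask handling above (shapes are illustrative): xformers wants the bias storage width padded to a multiple of 8, so a padded buffer is allocated and a narrower view with the original logical width is passed on:

    import torch

    mask = torch.randn(2, 8, 77, 77)       # (batch, heads, q_len, kv_len); 77 not divisible by 8
    pad = 8 - mask.shape[-1] % 8           # 77 -> pad by 3 -> storage width 80
    mask_out = torch.empty(mask.shape[:-1] + (mask.shape[-1] + pad,))
    mask_out[..., :mask.shape[-1]] = mask
    mask = mask_out[..., :mask.shape[-1]]  # logical shape still (..., 77), storage aligned
    assert mask.shape[-1] == 77 and mask_out.shape[-1] == 80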
@@ -15,6 +15,7 @@ from .util import (
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.patcher_extension
import comfy.ops
ops = comfy.ops.disable_weight_init

@@ -47,6 +48,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
        elif isinstance(layer, Upsample):
            x = layer(x, output_shape=output_shape)
        else:
            if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
                found_patched = False
                for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
                    if isinstance(layer, class_type):
                        x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
                        found_patched = True
                        break
                if found_patched:
                    continue
            x = layer(x)
    return x
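A hedged sketch of how a custom node could register a forward_timestep_embed_patch; MyLayer and my_handler are hypothetical names, only the patch key and handler signature come from the diff:

    import torch

    class MyLayer(torch.nn.Module):  # hypothetical layer type to intercept
        def forward(self, x):
            return x

    def my_handler(layer, x, emb, context, transformer_options, output_shape,
                   time_context, num_video_frames, image_only_indicator):
        # a real handler would run custom logic around the layer call
        return layer(x)

    model_options = {"transformer_options": {}}
    patches = model_options["transformer_options"].setdefault("patches", {})
    patches.setdefault("forward_timestep_embed_patch", []).append((MyLayer, my_handler))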
@@ -819,6 +829,13 @@ class UNetModel(nn.Module):
        )

    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, timesteps, context, y, control, transformer_options, **kwargs)

    def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        if mask.shape[1] == 1:
            return mask
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]
@@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
}


def load_lora(lora, to_load):
def load_lora(lora, to_load, log_missing=True):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
@@ -49,10 +49,20 @@ def load_lora(lora, to_load):
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        reshape_name = "{}.reshape_weight".format(x)
        reshape = None
        if reshape_name in lora.keys():
            try:
                reshape = lora[reshape_name].tolist()
                loaded_keys.add(reshape_name)
            except:
                pass

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

@@ -72,6 +82,10 @@ def load_lora(lora, to_load):
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
@@ -82,7 +96,7 @@ def load_lora(lora, to_load):
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)

@@ -193,9 +207,16 @@ def load_lora(lora, to_load):
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    for x in lora.keys():
        if x not in loaded_keys:
            logging.warning("lora key not loaded: {}".format(x))
        set_weight_name = "{}.set_weight".format(x)
        set_weight = lora.get(set_weight_name, None)
        if set_weight is not None:
            patch_dict[to_load[x]] = ("set", (set_weight,))
            loaded_keys.add(set_weight_name)

    if log_missing:
        for x in lora.keys():
            if x not in loaded_keys:
                logging.warning("lora key not loaded: {}".format(x))

    return patch_dict
@@ -282,11 +303,14 @@ def model_lora_keys_unet(model, key_map={}):
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
        if k.startswith("diffusion_model."):
            if k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                key_map["lora_unet_{}".format(key_lora)] = k
                key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
            else:
                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
@@ -344,6 +368,12 @@ def model_lora_keys_unet(model, key_map={}):
        key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
        key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    return key_map
@@ -400,7 +430,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten

    return padded_tensor

def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
    for p in patches:
        strength = p[0]
        v = p[1]
@@ -440,10 +470,22 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
            else:
                weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
        elif patch_type == "set":
            weight.copy_(v[0])
        elif patch_type == "model_as_lora":
            target_weight: torch.Tensor = v[0]
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        elif patch_type == "lora": #lora/locon
            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
            dora_scale = v[4]
            reshape = v[5]

            if reshape is not None:
                weight = pad_tensor_to_shape(weight, reshape)

            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
17 comfy/lora_convert.py Normal file
@@ -0,0 +1,17 @@
import torch


def convert_lora_bfl_control(sd): #BFL loras for Flux
    sd_out = {}
    for k in sd:
        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
        sd_out[k_to] = sd[k]

    sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
    return sd_out


def convert_lora(sd):
    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
        return convert_lora_bfl_control(sd)
    return sd
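An illustrative trace (not part of the diff) of the key rewriting convert_lora_bfl_control performs; the keys below are made up but follow the BFL naming scheme that convert_lora checks for:

    example_keys = [
        "img_in.lora_A.weight",                      # kept, just prefixed
        "double_blocks.0.img_attn.qkv.lora_B.bias",  # .lora_B.bias -> .diff_b
        "single_blocks.0.norm.key_norm.scale",       # _norm.scale -> _norm.scale.set_weight
    ]
    for k in example_keys:
        print("diffusion_model.{}".format(
            k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")))
    # diffusion_model.img_in.lora_A.weight
    # diffusion_model.double_blocks.0.img_attn.qkv.diff_b
    # diffusion_model.single_blocks.0.norm.key_norm.scale.set_weight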
@@ -30,14 +30,19 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model

import comfy.model_management
import comfy.patcher_extension
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
import comfy.latent_formats
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher

class ModelType(Enum):
    EPS = 1
@@ -94,6 +99,7 @@ class BaseModel(torch.nn.Module):
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device = device
        self.current_patcher: 'ModelPatcher' = None

        if not unet_config.get("disable_unet_model_creation", False):
            if model_config.custom_operations is None:
@@ -119,6 +125,13 @@ class BaseModel(torch.nn.Module):
        self.memory_usage_factor = model_config.memory_usage_factor

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._apply_model,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
        ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
@@ -153,8 +166,7 @@ class BaseModel(torch.nn.Module):
    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
        out = {}
    def concat_cond(self, **kwargs):
        if len(self.concat_keys) > 0:
            cond_concat = []
            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
@@ -193,7 +205,14 @@ class BaseModel(torch.nn.Module):
                elif ck == "masked_image":
                    cond_concat.append(self.blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
            return data
        return None

    def extra_conds(self, **kwargs):
        out = {}
        concat_cond = self.concat_cond(**kwargs)
        if concat_cond is not None:
            out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)

        adm = self.encode_adm(**kwargs)
        if adm is not None:
@@ -523,9 +542,7 @@ class SD_X4Upscaler(BaseModel):
        return out

class IP2P:
    def extra_conds(self, **kwargs):
        out = {}

    def concat_cond(self, **kwargs):
        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]
@@ -537,18 +554,15 @@ class IP2P:
        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        image = utils.resize_to_batch_size(image, noise.shape[0])
        return self.process_ip2p_image_in(image)

        out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

class SD15_instructpix2pix(IP2P, BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.process_ip2p_image_in = lambda image: image


class SDXL_instructpix2pix(IP2P, SDXL):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -709,6 +723,44 @@ class Flux(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)

    def concat_cond(self, **kwargs):
        try:
            #Handle Flux control loras dynamically changing the img_in weight.
            num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
        except:
            #Some cases like tensorrt might not have the weights accessible
            num_channels = self.model_config.unet_config["in_channels"]

        out_channels = self.model_config.unet_config["out_channels"]

        if num_channels <= out_channels:
            return None

        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
            image = torch.zeros_like(noise)

        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        image = utils.resize_to_batch_size(image, noise.shape[0])
        image = self.process_latent_in(image)
        if num_channels <= out_channels * 2:
            return image

        #inpaint model
        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.ones_like(noise)[:, :1]

        mask = torch.mean(mask, dim=1, keepdim=True)
        print(mask.shape)
        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])
        return torch.cat((image, mask), dim=1)
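The view/permute/reshape chain above is an 8x8 pixel-unshuffle: it packs each 8x8 block of the full-resolution mask into 64 channels at latent resolution. A quick equivalence check (illustrative, not part of the diff):

    import torch
    import torch.nn.functional as F

    mask = torch.rand(1, 1, 64, 64)
    packed = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8) \
                 .permute(0, 2, 4, 1, 3) \
                 .reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
    assert packed.shape == (1, 64, 8, 8)
    assert torch.equal(packed, F.pixel_unshuffle(mask, 8))  # same reindexing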
    def encode_adm(self, **kwargs):
        return kwargs["pooled_output"]

@@ -734,3 +786,23 @@ class GenmoMochi(BaseModel):
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out

class LTXV(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        guiding_latent = kwargs.get("guiding_latent", None)
        if guiding_latent is not None:
            out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)

        out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
        return out
@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
        dit_config = {}
        dit_config["image_model"] = "flux"
        dit_config["in_channels"] = 16
        patch_size = 2
        dit_config["patch_size"] = patch_size
        in_key = "{}img_in.weight".format(key_prefix)
        if in_key in state_dict_keys:
            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
        dit_config["out_channels"] = 16
        dit_config["vec_in_dim"] = 768
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
@@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
        dit_config["rope_theta"] = 10000.0
        return dit_config

    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
        dit_config = {}
        dit_config["image_model"] = "ltxv"
        return dit_config

    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None
@@ -321,8 +331,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
    if model_config is None and use_base_if_no_match:
        model_config = comfy.supported_models_base.BASE(unet_config)

    scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
    if scaled_fp8_weight is not None:
    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
    if scaled_fp8_key in state_dict:
        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
        model_config.scaled_fp8 = scaled_fp8_weight.dtype
        if model_config.scaled_fp8 == torch.float32:
            model_config.scaled_fp8 = torch.float8_e4m3fn
@@ -23,6 +23,8 @@ from comfy.cli_args import args
import torch
import sys
import platform
import weakref
import gc

class VRAMState(Enum):
    DISABLED = 0 #No vram present: no need to move models to vram
@@ -287,11 +289,27 @@ def module_size(module):

class LoadedModel:
    def __init__(self, model):
        self.model = model
        self._set_model(model)
        self.device = model.load_device
        self.weights_loaded = False
        self.real_model = None
        self.currently_used = True
        self.model_finalizer = None
        self._patcher_finalizer = None

    def _set_model(self, model):
        self._model = weakref.ref(model)
        if model.parent is not None:
            self._parent_model = weakref.ref(model.parent)
            self._patcher_finalizer = weakref.finalize(model, self._switch_parent)

    def _switch_parent(self):
        model = self._parent_model()
        if model is not None:
            self._set_model(model)

    @property
    def model(self):
        return self._model()

    def model_memory(self):
        return self.model.model_size()
@@ -306,32 +324,23 @@ class LoadedModel:
        return self.model_memory()

    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
        patch_model_to = self.device

        self.model.model_patches_to(self.device)
        self.model.model_patches_to(self.model.model_dtype())

        load_weights = not self.weights_loaded
        # if self.model.loaded_size() > 0:
        use_more_vram = lowvram_model_memory
        if use_more_vram == 0:
            use_more_vram = 1e32
        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
        real_model = self.model.model

        if self.model.loaded_size() > 0:
            use_more_vram = lowvram_model_memory
            if use_more_vram == 0:
                use_more_vram = 1e32
            self.model_use_more_vram(use_more_vram)
        else:
            try:
                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
            except Exception as e:
                self.model.unpatch_model(self.model.offload_device)
                self.model_unload()
                raise e

        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
            with torch.no_grad():
                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
                real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

        self.weights_loaded = True
        return self.real_model
        self.real_model = weakref.ref(real_model)
        self.model_finalizer = weakref.finalize(real_model, cleanup_models)
        return real_model

    def should_reload_model(self, force_patch_weights=False):
        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -344,18 +353,26 @@ class LoadedModel:
        freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
        if freed >= memory_to_free:
            return False
        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
        self.model.model_patches_to(self.model.offload_device)
        self.weights_loaded = self.weights_loaded and not unpatch_weights
        self.model.detach(unpatch_weights)
        self.model_finalizer.detach()
        self.model_finalizer = None
        self.real_model = None
        return True

    def model_use_more_vram(self, extra_memory):
        return self.model.partially_load(self.device, extra_memory)
    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)

    def __eq__(self, other):
        return self.model is other.model

    def __del__(self):
        if self._patcher_finalizer is not None:
            self._patcher_finalizer.detach()

    def is_dead(self):
        return self.real_model() is not None and self.model is None
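The LoadedModel changes above hinge on two stdlib facilities; a minimal sketch of the pattern (illustrative, not ComfyUI code):

    import weakref

    class Patcher:  # stand-in for a ModelPatcher
        pass

    p = Patcher()
    ref = weakref.ref(p)                                    # does not keep p alive
    fin = weakref.finalize(p, lambda: print("collected"))   # runs when p is collected
    assert ref() is p       # alive: dereferencing returns the object
    del p                   # drop the last strong reference (prints "collected" on CPython)
    assert ref() is None    # dead: ref() returns None, which is what is_dead() checks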
def use_more_memory(extra_memory, loaded_models, device):
    for m in loaded_models:
        if m.device == device:
@@ -386,38 +403,8 @@ def extra_reserved_memory():
def minimum_inference_memory():
    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def unload_model_clones(model, unload_weights_only=True, force_unload=True):
    to_unload = []
    for i in range(len(current_loaded_models)):
        if model.is_clone(current_loaded_models[i].model):
            to_unload = [i] + to_unload

    if len(to_unload) == 0:
        return True

    same_weights = 0
    for i in to_unload:
        if model.clone_has_same_weights(current_loaded_models[i].model):
            same_weights += 1

    if same_weights == len(to_unload):
        unload_weight = False
    else:
        unload_weight = True

    if not force_unload:
        if unload_weights_only and unload_weight == False:
            return None
    else:
        unload_weight = True

    for i in to_unload:
        logging.debug("unload clone {} {}".format(i, unload_weight))
        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)

    return unload_weight

def free_memory(memory_required, device, keep_loaded=[]):
    cleanup_models_gc()
    unloaded_model = []
    can_unload = []
    unloaded_models = []
@@ -425,7 +412,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    for i in range(len(current_loaded_models) -1, -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded:
            if shift_model not in keep_loaded and not shift_model.is_dead():
                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                shift_model.currently_used = False

@@ -454,6 +441,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    cleanup_models_gc()
    global vram_state

    inference_memory = minimum_inference_memory()
@@ -466,11 +454,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
    models = set(models)

    models_to_load = []
    models_already_loaded = []

    for x in models:
        loaded_model = LoadedModel(x)
        loaded = None

        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
        except:
@@ -478,51 +464,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

        if loaded_model_index is not None:
            loaded = current_loaded_models[loaded_model_index]
            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
                loaded = None
            else:
                loaded.currently_used = True
                models_already_loaded.append(loaded)

        if loaded is None:
            loaded.currently_used = True
            models_to_load.append(loaded)
        else:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    if len(models_to_load) == 0:
        devs = set(map(lambda a: a.device, models_already_loaded))
        for d in devs:
            if d != torch.device("cpu"):
                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
                free_mem = get_free_memory(d)
                if free_mem < minimum_memory_required:
                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
                    models_to_load = free_memory(minimum_memory_required, d)
                    logging.info("{} models unloaded.".format(len(models_to_load)))
                else:
                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
        if len(models_to_load) == 0:
            return

    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
    for loaded_model in models_to_load:
        to_unload = []
        for i in range(len(current_loaded_models)):
            if loaded_model.model.is_clone(current_loaded_models[i].model):
                to_unload = [i] + to_unload
        for i in to_unload:
            current_loaded_models.pop(i).model.detach(unpatch_all=False)

    total_memory_required = {}
    for loaded_model in models_to_load:
        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for loaded_model in models_already_loaded:
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for loaded_model in models_to_load:
        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
        if weights_unloaded is not None:
            loaded_model.weights_loaded = not weights_unloaded

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_mem = get_free_memory(device)
            if free_mem < minimum_memory_required:
                models_l = free_memory(minimum_memory_required, device)
                logging.info("{} models unloaded.".format(len(models_l)))

    for loaded_model in models_to_load:
        model = loaded_model.model
@@ -544,17 +514,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

        cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        current_loaded_models.insert(0, loaded_model)

    devs = set(map(lambda a: a.device, models_already_loaded))
    for d in devs:
        if d != torch.device("cpu"):
            free_mem = get_free_memory(d)
            if free_mem > minimum_memory_required:
                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
    return


def load_model_gpu(model):
    return load_models_gpu([model])

@@ -568,21 +529,35 @@ def loaded_models(only_currently_used=False):
            output.append(m.model)
    return output

def cleanup_models(keep_clone_weights_loaded=False):

def cleanup_models_gc():
    do_gc = False
    for i in range(len(current_loaded_models)):
        cur = current_loaded_models[i]
        if cur.is_dead():
            logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
            do_gc = True
            break

    if do_gc:
        gc.collect()
        soft_empty_cache()

        for i in range(len(current_loaded_models)):
            cur = current_loaded_models[i]
            if cur.is_dead():
                logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))


def cleanup_models():
    to_delete = []
    for i in range(len(current_loaded_models)):
        #TODO: very fragile function needs improvement
        num_refs = sys.getrefcount(current_loaded_models[i].model)
        if num_refs <= 2:
            if not keep_clone_weights_loaded:
                to_delete = [i] + to_delete
            #TODO: find a less fragile way to do this.
            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
                to_delete = [i] + to_delete
        if current_loaded_models[i].real_model() is None:
            to_delete = [i] + to_delete

    for i in to_delete:
        x = current_loaded_models.pop(i)
        x.model_unload()
        del x

def dtype_size(dtype):
@@ -628,6 +603,10 @@ def maximum_vram_for_weights(device=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    if model_params < 0:
        model_params = 1000000000000000000000
    if args.fp32_unet:
        return torch.float32
    if args.fp64_unet:
        return torch.float64
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
@@ -674,7 +653,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor

# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    if weight_dtype == torch.float32:
    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
File diff suppressed because it is too large
@@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math

def rescale_zero_terminal_snr_sigmas(sigmas):
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= (alphas_bar_sqrt_T)

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas_bar[-1] = 4.8973451890853435e-08
    return ((1 - alphas_bar) / alphas_bar) ** 0.5
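A quick numeric check (illustrative) of the rescaling above: the terminal alpha_bar is pinned to ~4.9e-08, so the last sigma becomes very large, i.e. the final timestep carries (almost) pure noise, which is the point of zero-terminal-SNR schedules:

    import torch

    sigmas = torch.linspace(0.01, 14.6, 1000)           # an illustrative schedule
    rescaled = rescale_zero_terminal_snr_sigmas(sigmas)
    print(float(rescaled[-1]))                          # roughly 4.5e3, vs 14.6 before
    assert rescaled[-1] > sigmas[-1]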
class EPS:
    def calculate_input(self, sigma, noise):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@@ -48,7 +67,7 @@ class CONST:
        return latent / (1.0 - sigma)

class ModelSamplingDiscrete(torch.nn.Module):
    def __init__(self, model_config=None):
    def __init__(self, model_config=None, zsnr=None):
        super().__init__()

        if model_config is not None:
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
            linear_end = sampling_settings.get("linear_end", 0.012)
            timesteps = sampling_settings.get("timesteps", 1000)

        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
        if zsnr is None:
            zsnr = sampling_settings.get("zsnr", False)

        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
        self.sigma_data = 1.0

    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
        if given_betas is not None:
            betas = given_betas
        else:
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if zsnr:
            sigmas = rescale_zero_terminal_snr_sigmas(sigmas)

        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas):
156 comfy/patcher_extension.py Normal file
@@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Callable

class CallbacksMP:
    ON_CLONE = "on_clone"
    ON_LOAD = "on_load_after"
    ON_DETACH = "on_detach_after"
    ON_CLEANUP = "on_cleanup"
    ON_PRE_RUN = "on_pre_run"
    ON_PREPARE_STATE = "on_prepare_state"
    ON_APPLY_HOOKS = "on_apply_hooks"
    ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
    ON_INJECT_MODEL = "on_inject_model"
    ON_EJECT_MODEL = "on_eject_model"

    # callbacks dict is in the format:
    # {"call_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
        return {}

def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
    add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)

def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
    c = callbacks.setdefault(call_type, {}).setdefault(key, [])
    c.append(callback)

def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    c_list = []
    callbacks: dict[str, list] = transformer_options.get("callbacks", {})
    c_list.extend(callbacks.get(call_type, {}).get(key, []))
    return c_list

def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    c_list = []
    callbacks: dict[str, list] = transformer_options.get("callbacks", {})
    for c in callbacks.get(call_type, {}).values():
        c_list.extend(c)
    return c_list

class WrappersMP:
    OUTER_SAMPLE = "outer_sample"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"

    # wrappers dict is in the format:
    # {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
        return {}

def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)

def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
    w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
    w.append(wrapper)

def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    w_list = []
    wrappers: dict[str, list] = transformer_options.get("wrappers", {})
    w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
    return w_list

def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    w_list = []
    wrappers: dict[str, list] = transformer_options.get("wrappers", {})
    for w in wrappers.get(wrapper_type, {}).values():
        w_list.extend(w)
    return w_list

class WrapperExecutor:
    """Handles call stack of wrappers around a function in an ordered manner."""
    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
        # NOTE: class_obj exists so that wrappers surrounding a class method can access
        # the class instance at runtime via executor.class_obj
        self.original = original
        self.class_obj = class_obj
        self.wrappers = wrappers.copy()
        self.idx = idx
        self.is_last = idx == len(wrappers)

    def __call__(self, *args, **kwargs):
        """Calls the next wrapper or original function, whichever is appropriate."""
        new_executor = self._create_next_executor()
        return new_executor.execute(*args, **kwargs)

    def execute(self, *args, **kwargs):
        """Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
        args = list(args)
        kwargs = dict(kwargs)
        if self.is_last:
            return self.original(*args, **kwargs)
        return self.wrappers[self.idx](self, *args, **kwargs)

    def _create_next_executor(self) -> 'WrapperExecutor':
        new_idx = self.idx + 1
        if new_idx > len(self.wrappers):
            raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
        if self.class_obj is None:
            return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
        return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)

    @classmethod
    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)

    @classmethod
    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
        return cls(original, class_obj, wrappers, idx=idx)

class PatcherInjection:
    def __init__(self, inject: Callable, eject: Callable):
        self.inject = inject
        self.eject = eject

def copy_nested_dicts(input_dict: dict):
    new_dict = input_dict.copy()
    for key, value in input_dict.items():
        if isinstance(value, dict):
            new_dict[key] = copy_nested_dicts(value)
        elif isinstance(value, list):
            new_dict[key] = value.copy()
    return new_dict

def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    if copy_dict1:
        merged_dict = copy_nested_dicts(dict1)
    else:
        merged_dict = dict1
    for key, value in dict2.items():
        if isinstance(value, dict):
            curr_value = merged_dict.setdefault(key, {})
            merged_dict[key] = merge_nested_dicts(value, curr_value)
        elif isinstance(value, list):
            merged_dict.setdefault(key, []).extend(value)
        else:
            merged_dict[key] = value
    return merged_dict
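An illustrative end-to-end use of the wrapper mechanism above; the wrapper name and the model_options dict are assumptions, the helper and class names come from this file:

    def logging_wrapper(executor, *args, **kwargs):
        # custom logic can run before and after; calling executor(...) continues the chain
        result = executor(*args, **kwargs)
        return result

    model_options = {}
    add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, "my_node", logging_wrapper,
                         model_options, is_model_options=True)

    def sample():
        return "sampled"

    wrappers = get_all_wrappers(WrappersMP.OUTER_SAMPLE, model_options, is_model_options=True)
    assert WrapperExecutor.new_executor(sample, wrappers).execute() == "sampled"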
@@ -1,22 +1,61 @@
from __future__ import annotations
import uuid
import torch
import comfy.model_management
import comfy.conds
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.model_base import BaseModel
    from comfy.controlnet import ControlBase

def prepare_mask(noise_mask, shape, device):
    """ensures noise mask is of proper dimensions"""
    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
    noise_mask = noise_mask.to(device)
    return noise_mask
    return comfy.utils.reshape_mask(noise_mask, shape).to(device)

def get_models_from_cond(cond, model_type):
    models = []
    for c in cond:
        if model_type in c:
            models += [c[model_type]]
            if isinstance(c[model_type], list):
                models += c[model_type]
            else:
                models += [c[model_type]]
    return models

def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
    # get hooks from conds, and collect cnets so they can be checked for extra_hooks
    cnets: list[ControlBase] = []
    for c in cond:
        if 'hooks' in c:
            for hook in c['hooks'].hooks:
                hook: comfy.hooks.Hook
                with_type = hooks_dict.setdefault(hook.hook_type, {})
                with_type[hook] = None
        if 'control' in c:
            cnets.append(c['control'])

    def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
        if cnet.extra_hooks is not None:
            _list.append(cnet.extra_hooks)
        if cnet.previous_controlnet is None:
            return _list
        return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)

    hooks_list = []
    cnets = set(cnets)
    for base_cnet in cnets:
        get_extra_hooks_from_cnet(base_cnet, hooks_list)
    extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
    if extra_hooks is not None:
        for hook in extra_hooks.hooks:
            with_type = hooks_dict.setdefault(hook.hook_type, {})
            with_type[hook] = None

    return hooks_dict

def convert_cond(cond):
    out = []
    for c in cond:
@@ -26,17 +65,22 @@ def convert_cond(cond):
        model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
        temp["cross_attn"] = c[0]
        temp["model_conds"] = model_conds
        temp["uuid"] = uuid.uuid4()
        out.append(temp)
    return out

def get_additional_models(conds, dtype):
    """loads additional models in conditioning"""
    cnets = []
    cnets: list[ControlBase] = []
    gligen = []
    add_models = []
    hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}

    for k in conds:
        cnets += get_models_from_cond(conds[k], "control")
        gligen += get_models_from_cond(conds[k], "gligen")
        add_models += get_models_from_cond(conds[k], "additional_models")
        get_hooks_from_cond(conds[k], hooks)

    control_nets = set(cnets)

@@ -47,7 +91,9 @@ def get_additional_models(conds, dtype):
        inference_memory += m.inference_memory_requirements(dtype)

    gligen = [x[1] for x in gligen]
    models = control_models + gligen
    hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
    models = control_models + gligen + add_models + hook_models

    return models, inference_memory

def cleanup_additional_models(models):
@@ -57,10 +103,11 @@ def cleanup_additional_models(models):
        m.cleanup()


def prepare_sampling(model, noise_shape, conds):
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
    device = model.load_device
    real_model = None
    real_model: 'BaseModel' = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
@@ -76,3 +123,14 @@ def cleanup_models(conds, models):
        control_cleanup += get_models_from_cond(conds[k], "control")

    cleanup_additional_models(set(control_cleanup))

def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
    # check for hooks in conds - if not registered, see if can be applied
    hooks = {}
    for k in conds:
        get_hooks_from_cond(conds[k], hooks)
    # add wrappers and callbacks from ModelPatcher to transformer_options
    model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
    model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
    # register hooks on model/model_options
    model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
@@ -1,11 +1,21 @@
|
||||
from __future__ import annotations
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
import torch
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -70,6 +80,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
||||
patches = None
|
||||
@@ -85,8 +96,8 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@@ -138,110 +149,184 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
|
||||
# need to figure out remaining unmasked area for conds
|
||||
default_mults = []
|
||||
for _ in default_conds:
|
||||
default_mults.append(torch.ones_like(x_in))
|
||||
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
|
||||
for lora_hooks, to_run in hooked_to_run.items():
|
||||
for cond_obj, i in to_run:
|
||||
# if no default_cond for cond_type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
area: list[int] = cond_obj.area
|
||||
if area is not None:
|
||||
curr_default_mult: torch.Tensor = default_mults[i]
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
curr_default_mult = curr_default_mult.narrow(i + 2, area[i + dims], area[i])
|
||||
curr_default_mult -= cond_obj.mult
|
||||
else:
|
||||
default_mults[i] -= cond_obj.mult
|
||||
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
|
||||
for i, mult in enumerate(default_mults):
|
||||
# if no default_cond for cond type, do nothing
|
||||
if len(default_conds[i]) == 0:
|
||||
continue
|
||||
torch.nn.functional.relu(mult, inplace=True)
|
||||
# if mult is all zeros, then don't add default_cond
|
||||
if torch.max(mult) == 0.0:
|
||||
continue
|
||||
|
||||
cond = default_conds[i]
|
||||
for x in cond:
|
||||
# do get_area_and_mult to get all the expected values
|
||||
p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
# replace p's mult with calculated mult
|
||||
p = p._replace(mult=mult)
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    out_conds = []
    out_counts = []
    to_run = []
    # separate conds by matching hooks
    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
    default_conds = []
    has_default_conds = False

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        default_c = []
        if cond is not None:
            for x in cond:
                p = get_area_and_mult(x, x_in, timestep)
                if 'default' in x:
                    default_c.append(x)
                    has_default_conds = True
                    continue
                p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue
                if p.hooks is not None:
                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
                hooked_to_run.setdefault(p.hooks, list())
                hooked_to_run[p.hooks] += [(p, i)]
        default_conds.append(default_c)

                to_run += [(p, i)]
    if has_default_conds:
        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)

    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]
    model.current_patcher.prepare_state(timestep)

        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]
    # run every hooked_to_run separately
    for hooks, to_run in hooked_to_run.items():
        while len(to_run) > 0:
            first = to_run[0]
            first_shape = first[0][0].shape
            to_batch_temp = []
            for x in range(len(to_run)):
                if can_concat_cond(to_run[x][0], first[0]):
                    to_batch_temp += [x]

        free_memory = model_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break
            to_batch_temp.reverse()
            to_batch = to_batch_temp[:1]

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches
            free_memory = model_management.get_free_memory(x_in.device)
            for i in range(1, len(to_batch_temp) + 1):
                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
                if model.memory_required(input_shape) * 1.5 < free_memory:
                    to_batch = batch_amount
                    break

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)
            input_x = []
            mult = []
            c = []
            cond_or_uncond = []
            uuids = []
            area = []
            control = None
            patches = None
            for x in to_batch:
                o = to_run.pop(x)
                p = o[0]
                input_x.append(p.input_x)
                mult.append(p.mult)
                c.append(p.conditioning)
                area.append(p.area)
                cond_or_uncond.append(o[1])
                uuids.append(p.uuid)
                control = p.control
                patches = p.patches

        if control is not None:
            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
            batch_chunks = len(cond_or_uncond)
            input_x = torch.cat(input_x)
            c = cond_cat(c)
            timestep_ = torch.cat([timestep] * batch_chunks)

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()
            transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
            if 'transformer_options' in model_options:
                transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
                                                                                 model_options['transformer_options'],
                                                                                 copy_dict1=False)

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            if patches is not None:
                # TODO: replace with merge_nested_dicts function
                if "patches" in transformer_options:
                    cur_patches = transformer_options["patches"].copy()
                    for p in patches:
                        if p in cur_patches:
                            cur_patches[p] = cur_patches[p] + patches[p]
                        else:
                            cur_patches[p] = patches[p]
                    transformer_options["patches"] = cur_patches
                else:
                    transformer_options["patches"] = patches

            transformer_options["cond_or_uncond"] = cond_or_uncond[:]
            transformer_options["uuids"] = uuids[:]
            transformer_options["sigmas"] = timestep

            c['transformer_options'] = transformer_options

            if control is not None:
                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

            if 'model_function_wrapper' in model_options:
                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
            else:
                transformer_options["patches"] = patches
                output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        c['transformer_options'] = transformer_options

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]
            for o in range(batch_chunks):
                cond_index = cond_or_uncond[o]
                a = area[o]
                if a is None:
                    out_conds[cond_index] += output[o] * mult[o]
                    out_counts[cond_index] += mult[o]
                else:
                    out_c = out_conds[cond_index]
                    out_cts = out_counts[cond_index]
                    dims = len(a) // 2
                    for i in range(dims):
                        out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                        out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                    out_c += output[o] * mult[o]
                    out_cts += mult[o]

    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]
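# Note: out_conds accumulates output * mult per cond type while out_counts
# accumulates the mults themselves, so the division above is a weighted average
# over every (possibly overlapping) masked cond. Seeding out_counts with 1e-37
# instead of zeros keeps uncovered pixels from dividing by zero. Worked
# one-pixel example:
#   out_conds  = 0.7*a + 0.3*b
#   out_counts = 0.7 + 0.3 + 1e-37 ~= 1.0
#   result     = out_conds / out_counts ~= 0.7*a + 0.3*b
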
@@ -500,10 +585,15 @@ def calculate_start_end_timesteps(model, conds):

        timestep_start = None
        timestep_end = None
        if 'start_percent' in x:
            timestep_start = s.percent_to_sigma(x['start_percent'])
        if 'end_percent' in x:
            timestep_end = s.percent_to_sigma(x['end_percent'])
        # handle clip hook schedule, if needed
        if 'clip_start_percent' in x:
            timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
            timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
        else:
            if 'start_percent' in x:
                timestep_start = s.percent_to_sigma(x['start_percent'])
            if 'end_percent' in x:
                timestep_end = s.percent_to_sigma(x['end_percent'])

        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
@@ -673,6 +763,12 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
            if k != kk:
                create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        for c in conds[k]:
            if 'hooks' in c:
                for hook in c['hooks'].hooks:
                    hook.initialize_timesteps(model)

    for k in conds:
        pre_run_control(model, conds[k])

@@ -685,9 +781,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N

    return conds


def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
    # determine which ControlNets have extra_hooks that should be combined with normal hooks
    hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
    for k in conds:
        for kk in conds[k]:
            if 'control' in kk:
                control: 'ControlBase' = kk['control']
                extra_hooks = control.get_extra_hooks()
                if len(extra_hooks) > 0:
                    hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
                    to_replace = hook_replacement.setdefault((control, hooks), [])
                    to_replace.append(kk)
    # if nothing to replace, do nothing
    if len(hook_replacement) == 0:
        return

    # for optimal sampling performance, common ControlNets + hook combos should have identical hooks
    # on the cond dicts
    for key, conds_to_modify in hook_replacement.items():
        control = key[0]
        hooks = key[1]
        hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
        # if combined hooks are not None, set as new hooks for all relevant conds
        if hooks is not None:
            for cond in conds_to_modify:
                cond['hooks'] = hooks


def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
    hooks_set = set()
    for k in conds:
        for kk in conds[k]:
            hooks_set.add(kk.get('hooks', None))
    return len(hooks_set)


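# Note: get_total_hook_groups_in_conds() only needs the number of *distinct*
# hook groups (None included), which drives the hook_mode decision in
# CFGGuider.sample() below. Quick illustration with strings standing in for
# HookGroup objects:

_conds = {"positive": [{"hooks": "groupA"}, {}], "negative": [{"hooks": "groupA"}]}
_hooks_set = set()
for _k in _conds:
    for _kk in _conds[_k]:
        _hooks_set.add(_kk.get('hooks', None))
assert len(_hooks_set) == 2  # {"groupA", None}: hooks will change during sampling
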
class CFGGuider:
    def __init__(self, model_patcher):
        self.model_patcher = model_patcher
        self.model_patcher: 'ModelPatcher' = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0
@@ -714,19 +847,17 @@ class CFGGuider:

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        extra_args = {"model_options": self.model_options, "seed":seed}
        extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}

        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            sampler.sample,
            sampler,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
        )
        samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        if sigmas.shape[-1] == 0:
            return latent_image

        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
        device = self.model_patcher.load_device

@@ -737,14 +868,48 @@ class CFGGuider:
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)

        output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        try:
            self.model_patcher.pre_run()
            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            self.model_patcher.cleanup()

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.conds
        del self.loaded_models
        return output

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        if sigmas.shape[-1] == 0:
            return latent_image

        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
        preprocess_conds_hooks(self.conds)

        try:
            orig_model_options = self.model_options
            self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
            # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
            orig_hook_mode = self.model_patcher.hook_mode
            if get_total_hook_groups_in_conds(self.conds) <= 1:
                self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
            comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
            executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                self.outer_sample,
                self,
                comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
            )
            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            self.model_options = orig_model_options
            self.model_patcher.hook_mode = orig_hook_mode
            self.model_patcher.restore_hook_patches()

        del self.conds
        return output


def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    cfg_guider = CFGGuider(model)

129  comfy/sd.py
@@ -1,13 +1,16 @@
from __future__ import annotations
import torch
from enum import Enum
import logging

from comfy import model_management
from comfy.utils import ProgressBar
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml

import comfy.utils
@@ -27,12 +30,17 @@ import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt

import comfy.model_patcher
import comfy.lora
import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd

import comfy.ldm.flux.redux

def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    key_map = {}
    if model is not None:
@@ -40,6 +48,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)

    lora = comfy.lora_convert.convert_lora(lora)
    loaded = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        new_modelpatcher = model.clone()
@@ -92,9 +101,13 @@ class CLIP:

        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
        self.patcher.is_clip = True
        self.apply_hooks_to_conds = None
        if params['device'] == load_device:
            model_management.load_models_gpu([self.patcher], force_full_load=True)
        self.layer_idx = None
        self.use_clip_schedule = False
        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

    def clone(self):
@@ -103,6 +116,8 @@ class CLIP:
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        n.use_clip_schedule = self.use_clip_schedule
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
        return n

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@@ -114,6 +129,69 @@ class CLIP:
    def tokenize(self, text, return_word_ids=False):
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def add_hooks_to_dict(self, pooled_dict: dict[str]):
        if self.apply_hooks_to_conds:
            pooled_dict["hooks"] = self.apply_hooks_to_conds
        return pooled_dict

    def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
        all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
        all_hooks = self.patcher.forced_hooks
        if all_hooks is None or not self.use_clip_schedule:
            # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
            return_pooled = "unprojected" if unprojected else True
            pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
            cond = pooled_dict.pop("cond")
            # add/update any keys with the provided add_dict
            pooled_dict.update(add_dict)
            all_cond_pooled.append([cond, pooled_dict])
        else:
            scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()

            self.cond_stage_model.reset_clip_options()
            if self.layer_idx is not None:
                self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
            if unprojected:
                self.cond_stage_model.set_clip_options({"projected_pooled": False})

            self.load_model()
            all_hooks.reset()
            self.patcher.patch_hooks(None)
            if show_pbar:
                pbar = ProgressBar(len(scheduled_keyframes))

            for scheduled_opts in scheduled_keyframes:
                t_range = scheduled_opts[0]
                # don't bother encoding any conds outside of start_percent and end_percent bounds
                if "start_percent" in add_dict:
                    if t_range[1] < add_dict["start_percent"]:
                        continue
                if "end_percent" in add_dict:
                    if t_range[0] > add_dict["end_percent"]:
                        continue
                hooks_keyframes = scheduled_opts[1]
                for hook, keyframe in hooks_keyframes:
                    hook.hook_keyframe._current_keyframe = keyframe
                # apply appropriate hooks with values that match new hook_keyframe
                self.patcher.patch_hooks(all_hooks)
                # perform encoding as normal
                o = self.cond_stage_model.encode_token_weights(tokens)
                cond, pooled = o[:2]
                pooled_dict = {"pooled_output": pooled}
                # add clip_start_percent and clip_end_percent in pooled
                pooled_dict["clip_start_percent"] = t_range[0]
                pooled_dict["clip_end_percent"] = t_range[1]
                # add/update any keys with the provided add_dict
                pooled_dict.update(add_dict)
                # add hooks stored on clip
                self.add_hooks_to_dict(pooled_dict)
                all_cond_pooled.append([cond, pooled_dict])
                if show_pbar:
                    pbar.update(1)
                model_management.throw_exception_if_processing_interrupted()
            all_hooks.reset()
        return all_cond_pooled

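# Note: encode_from_tokens_scheduled() returns a list of [cond, pooled_dict]
# pairs rather than a single cond. In the scheduled branch each pooled_dict
# carries clip_start_percent/clip_end_percent, which
# calculate_start_end_timesteps() in comfy/samplers.py uses to window each
# encoding to its keyframe range. Shape sketch (values illustrative only):
# [
#   [cond_0, {"pooled_output": p0, "clip_start_percent": 0.0, "clip_end_percent": 0.5, **add_dict}],
#   [cond_1, {"pooled_output": p1, "clip_start_percent": 0.5, "clip_end_percent": 1.0, **add_dict}],
# ]
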
    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
        self.cond_stage_model.reset_clip_options()

@@ -131,6 +209,7 @@ class CLIP:
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
            self.add_hooks_to_dict(out)
            return out

        if return_pooled:
@@ -245,7 +324,7 @@ class VAE:
            self.process_output = lambda audio: audio
            self.process_input = lambda audio: audio
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
        elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
            if "blocks.2.blocks.3.stack.5.weight" in sd:
                sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
            if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
@@ -257,6 +336,14 @@ class VAE:
            self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
            self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
            self.working_dtypes = [torch.float16, torch.float32]
        elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
            self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
            self.latent_channels = 128
            self.latent_dim = 3
            self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
            self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
            self.working_dtypes = [torch.bfloat16, torch.float32]
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
            self.first_stage_model = None
@@ -356,15 +443,33 @@ class VAE:
                elif dims == 2:
                    pixel_samples = self.decode_tiled_(samples_in)
                elif dims == 3:
                    pixel_samples = self.decode_tiled_3d(samples_in)
                    tile = 256 // self.spacial_compression_decode()
                    overlap = tile // 4
                    pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
        return pixel_samples

    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
        model_management.load_model_gpu(self.patcher)
        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
        return output.movedim(1,-1)
    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
        args = {}
        if tile_x is not None:
            args["tile_x"] = tile_x
        if tile_y is not None:
            args["tile_y"] = tile_y
        if overlap is not None:
            args["overlap"] = overlap

        if dims == 1:
            args.pop("tile_y")
            output = self.decode_tiled_1d(samples, **args)
        elif dims == 2:
            output = self.decode_tiled_(samples, **args)
        elif dims == 3:
            output = self.decode_tiled_3d(samples, **args)
        return output.movedim(1, -1)

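# Note: in the tiled fallback of decode() above, the default 3D tile size is
# derived from the VAE's spatial compression. For the LTXV VAE registered
# earlier (upscale_ratio (..., 32, 32)) that works out to:
tile = 256 // 32     # spacial_compression_decode() == 32 -> 8 latent cells per tile side
overlap = tile // 4  # 2 latent cells of overlap
assert (tile * 32, overlap * 32) == (256, 64)  # i.e. 256px tiles with 64px overlap
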
    def encode(self, pixel_samples):
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
@@ -404,6 +509,12 @@ class VAE:
    def get_sd(self):
        return self.first_stage_model.state_dict()

    def spacial_compression_decode(self):
        try:
            return self.upscale_ratio[-1]
        except:
            return self.upscale_ratio

class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model
@@ -417,6 +528,8 @@ def load_style_model(ckpt_path):
    keys = model_data.keys()
    if "style_embedding" in keys:
        model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    elif "redux_down.weight" in keys:
        model = comfy.ldm.flux.redux.ReduxImageEncoder()
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    model.load_state_dict(model_data)
@@ -430,6 +543,7 @@ class CLIPType(Enum):
    HUNYUAN_DIT = 5
    FLUX = 6
    MOCHI = 7
    LTXV = 8

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    clip_data = []
@@ -508,6 +622,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            elif clip_type == CLIPType.LTXV:
                clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
            else: #CLIPType.MOCHI
                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

comfy/supported_models.py
@@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt

from . import supported_models_base
from . import latent_formats
@@ -197,6 +198,8 @@ class SDXL(supported_models_base.BASE):
            self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            if "ztsnr" in state_dict: #Some zsnr anime checkpoints
                self.sampling_settings["zsnr"] = True
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS
@@ -656,6 +659,15 @@ class Flux(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))

class FluxInpaint(Flux):
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
        "in_channels": 96,
    }

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

class FluxSchnell(Flux):
    unet_config = {
        "image_model": "flux",
@@ -700,7 +712,34 @@ class GenmoMochi(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))

class LTXV(supported_models_base.BASE):
    unet_config = {
        "image_model": "ltxv",
    }

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
    sampling_settings = {
        "shift": 2.37,
    }

    unet_extra_config = {}
    latent_format = latent_formats.LTXV

    memory_usage_factor = 2.7

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.LTXV(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]

models += [SVD_img2vid]

18  comfy/text_encoders/lt.py (new file)
@@ -0,0 +1,18 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
import comfy.text_encoders.genmo

class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?


class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


def ltxv_te(*args, **kwargs):
    return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
@@ -848,3 +848,24 @@ class ProgressBar:

    def update(self, value):
        self.update_absolute(self.current + value)

def reshape_mask(input_mask, output_shape):
    dims = len(output_shape) - 2

    if dims == 1:
        scale_mode = "linear"

    if dims == 2:
        input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "bilinear"

    if dims == 3:
        if len(input_mask.shape) < 5:
            input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "trilinear"

    mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
    if mask.shape[1] < output_shape[1]:
        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
    return mask

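# Note: reshape_mask() adapts an arbitrary mask to a latent-shaped tensor: it
# picks linear/bilinear/trilinear interpolation from the output's
# dimensionality, then repeats channels and batch to match. A hedged usage
# sketch (the shapes are the only assumption):

import torch

input_mask = torch.rand(1, 64, 64)   # (batch, H, W) mask from a MASK input
output_shape = (2, 16, 8, 32, 32)    # (b, c, t, h, w) video latent
# dims == 3 -> reshaped to (1, 1, 1, 64, 64), trilinear-interpolated to
# (8, 32, 32), channel-repeated to 16, batch-repeated to 2:
# mask = reshape_mask(input_mask, output_shape); mask.shape == output_shape
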
39  comfy_execution/validation.py (new file)
@@ -0,0 +1,39 @@
from __future__ import annotations


def validate_node_input(
    received_type: str, input_type: str, strict: bool = False
) -> bool:
    """
    received_type and input_type are both strings of the form "T1,T2,...".

    If strict is True, the input_type must contain the received_type.
    For example, if received_type is "STRING" and input_type is "STRING,INT",
    this will return True. But if received_type is "STRING,INT" and input_type is
    "INT", this will return False.

    If strict is False, the input_type must have overlap with the received_type.
    For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
    this will return True.

    Supports pre-union type extension behaviour of ``__ne__`` overrides.
    """
    # If the types are exactly the same, we can return immediately
    # Use pre-union behaviour: inverse of `__ne__`
    if not received_type != input_type:
        return True

    # Not equal, and not strings
    if not isinstance(received_type, str) or not isinstance(input_type, str):
        return False

    # Split the type strings into sets for comparison
    received_types = set(t.strip() for t in received_type.split(","))
    input_types = set(t.strip() for t in input_type.split(","))

    if strict:
        # In strict mode, all received types must be in the input types
        return received_types.issubset(input_types)
    else:
        # In non-strict mode, there must be at least one type in common
        return len(received_types.intersection(input_types)) > 0
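# Note: usage examples following directly from the docstring semantics:
# validate_node_input("STRING", "STRING,INT")               -> True (overlap)
# validate_node_input("STRING,BOOLEAN", "STRING,INT")       -> True (non-strict: any overlap)
# validate_node_input("STRING,INT", "INT", strict=True)     -> False (not a subset)
# validate_node_input("STRING", "STRING,INT", strict=True)  -> True ({"STRING"} <= {"STRING", "INT"})
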
comfy_extras/nodes_clip_sdxl.py
@@ -17,8 +17,7 @@ class CLIPTextEncodeSDXLRefiner:

    def encode(self, clip, ascore, width, height, text):
        tokens = clip.tokenize(text)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width, "height": height}]], )
        return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )

class CLIPTextEncodeSDXL:
    @classmethod
@@ -47,8 +46,7 @@ class CLIPTextEncodeSDXL:
            tokens["l"] += empty["l"]
        while len(tokens["l"]) > len(tokens["g"]):
            tokens["g"] += empty["g"]
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
        return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )

NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,

comfy_extras/nodes_flux.py
@@ -18,10 +18,7 @@ class CLIPTextEncodeFlux:
        tokens = clip.tokenize(clip_l)
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]

        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        cond = output.pop("cond")
        output["guidance"] = guidance
        return ([[cond, output]], )
        return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )

class FluxGuidance:
    @classmethod

745  comfy_extras/nodes_hooks.py (new file)
@@ -0,0 +1,745 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
import torch
from collections.abc import Iterable

if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.sd import CLIP

import comfy.hooks
import comfy.sd
import comfy.utils
import folder_paths

###########################################
# Mask, Combine, and Hook Conditioning
#------------------------------------------
class PairConditioningSetProperties:
    NodeId = 'PairConditioningSetProperties'
    NodeName = 'Cond Pair Set Props'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive_NEW": ("CONDITIONING", ),
                "negative_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "mask": ("MASK", ),
                "hooks": ("HOOKS",),
                "timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    def set_properties(self, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
                                                                     strength=strength, set_cond_area=set_cond_area,
                                                                     mask=mask, hooks=hooks, timesteps_range=timesteps)
        return (final_positive, final_negative)

class PairConditioningSetPropertiesAndCombine:
    NodeId = 'PairConditioningSetPropertiesAndCombine'
    NodeName = 'Cond Pair Set Props Combine'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "positive_NEW": ("CONDITIONING", ),
                "negative_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "mask": ("MASK", ),
                "hooks": ("HOOKS",),
                "timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    def set_properties(self, positive, negative, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
                                                                                 strength=strength, set_cond_area=set_cond_area,
                                                                                 mask=mask, hooks=hooks, timesteps_range=timesteps)
        return (final_positive, final_negative)

class ConditioningSetProperties:
    NodeId = 'ConditioningSetProperties'
    NodeName = 'Cond Set Props'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "mask": ("MASK", ),
                "hooks": ("HOOKS",),
                "timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    def set_properties(self, cond_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
                                                    strength=strength, set_cond_area=set_cond_area,
                                                    mask=mask, hooks=hooks, timesteps_range=timesteps)
        return (final_cond,)

class ConditioningSetPropertiesAndCombine:
    NodeId = 'ConditioningSetPropertiesAndCombine'
    NodeName = 'Cond Set Props Combine'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING", ),
                "cond_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "mask": ("MASK", ),
                "hooks": ("HOOKS",),
                "timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    def set_properties(self, cond, cond_NEW,
                       strength: float, set_cond_area: str,
                       mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
        (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
                                                                strength=strength, set_cond_area=set_cond_area,
                                                                mask=mask, hooks=hooks, timesteps_range=timesteps)
        return (final_cond,)

class PairConditioningCombine:
    NodeId = 'PairConditioningCombine'
    NodeName = 'Cond Pair Combine'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive_A": ("CONDITIONING",),
                "negative_A": ("CONDITIONING",),
                "positive_B": ("CONDITIONING",),
                "negative_B": ("CONDITIONING",),
            },
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "combine"

    def combine(self, positive_A, negative_A, positive_B, negative_B):
        final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
        return (final_positive, final_negative,)

class PairConditioningSetDefaultAndCombine:
    NodeId = 'PairConditioningSetDefaultCombine'
    NodeName = 'Cond Pair Set Default Combine'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "positive_DEFAULT": ("CONDITIONING",),
                "negative_DEFAULT": ("CONDITIONING",),
            },
            "optional": {
                "hooks": ("HOOKS",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_default_and_combine"

    def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
                                hooks: comfy.hooks.HookGroup=None):
        final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
                                                                                   hooks=hooks)
        return (final_positive, final_negative)

class ConditioningSetDefaultAndCombine:
    NodeId = 'ConditioningSetDefaultCombine'
    NodeName = 'Cond Set Default Combine'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "cond_DEFAULT": ("CONDITIONING",),
            },
            "optional": {
                "hooks": ("HOOKS",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_default_and_combine"

    def set_default_and_combine(self, cond, cond_DEFAULT,
                                hooks: comfy.hooks.HookGroup=None):
        (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
                                                                          hooks=hooks)
        return (final_conditioning,)

class SetClipHooks:
    NodeId = 'SetClipHooks'
    NodeName = 'Set CLIP Hooks'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip": ("CLIP",),
                "apply_to_conds": ("BOOLEAN", {"default": True}),
                "schedule_clip": ("BOOLEAN", {"default": False})
            },
            "optional": {
                "hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CLIP",)
    CATEGORY = "advanced/hooks/clip"
    FUNCTION = "apply_hooks"

    def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
        if hooks is not None:
            clip = clip.clone()
            if apply_to_conds:
                clip.apply_hooks_to_conds = hooks
            clip.patcher.forced_hooks = hooks.clone()
            clip.use_clip_schedule = schedule_clip
            if not clip.use_clip_schedule:
                clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
            clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
        return (clip,)

class ConditioningTimestepsRange:
    NodeId = 'ConditioningTimestepsRange'
    NodeName = 'Timesteps Range'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
            },
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
    RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
    CATEGORY = "advanced/hooks"
    FUNCTION = "create_range"

    def create_range(self, start_percent: float, end_percent: float):
        return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
#------------------------------------------
###########################################


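# Note: create_range() above returns the requested window plus its complement on
# both sides, so a single node can drive "during / before / after" scheduling:
#   create_range(0.2, 0.6) -> ((0.2, 0.6), (0.0, 0.2), (0.6, 1.0))
#   i.e. TIMESTEPS_RANGE, BEFORE_RANGE, AFTER_RANGE respectively
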
###########################################
# Create Hooks
#------------------------------------------
class CreateHookLora:
    NodeId = 'CreateHookLora'
    NodeName = 'Create Hook LoRA'
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook"

    def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
        if prev_hooks is None:
            prev_hooks = comfy.hooks.HookGroup()
        prev_hooks.clone()

        if strength_model == 0 and strength_clip == 0:
            return (prev_hooks,)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                lora = self.loaded_lora[1]
            else:
                temp = self.loaded_lora
                self.loaded_lora = None
                del temp

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
        return (prev_hooks.clone_and_combine(hooks),)

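# Note: loaded_lora above is a single-slot cache keyed on the file path:
# re-running the node with the same lora_name skips the disk load, and switching
# files drops the old state dict before loading the new one. A generic sketch of
# the same pattern (OneSlotCache is hypothetical, not part of comfy):

class OneSlotCache:
    def __init__(self, loader):
        self.loader = loader
        self.slot = None  # (key, value) or None
    def get(self, key):
        if self.slot is not None and self.slot[0] == key:
            return self.slot[1]     # hit: reuse without reloading
        self.slot = None            # evict old value before loading the new one
        value = self.loader(key)
        self.slot = (key, value)
        return value
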
class CreateHookLoraModelOnly(CreateHookLora):
    NodeId = 'CreateHookLoraModelOnly'
    NodeName = 'Create Hook LoRA (MO)'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook_model_only"

    def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
        return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)

class CreateHookModelAsLora:
    NodeId = 'CreateHookModelAsLora'
    NodeName = 'Create Hook Model as LoRA'

    def __init__(self):
        # when not None, will be in following format:
        # (ckpt_path: str, weights_model: dict, weights_clip: dict)
        self.loaded_weights = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook"

    def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
                    prev_hooks: comfy.hooks.HookGroup=None):
        if prev_hooks is None:
            prev_hooks = comfy.hooks.HookGroup()
        prev_hooks.clone()

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        weights_model = None
        weights_clip = None
        if self.loaded_weights is not None:
            if self.loaded_weights[0] == ckpt_path:
                weights_model = self.loaded_weights[1]
                weights_clip = self.loaded_weights[2]
            else:
                temp = self.loaded_weights
                self.loaded_weights = None
                del temp

        if weights_model is None:
            out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
            weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
            weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
            self.loaded_weights = (ckpt_path, weights_model, weights_clip)

        hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
                                                      strength_model=strength_model, strength_clip=strength_clip)
        return (prev_hooks.clone_and_combine(hooks),)

class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
    NodeId = 'CreateHookModelAsLoraModelOnly'
    NodeName = 'Create Hook Model as LoRA (MO)'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            },
            "optional": {
                "prev_hooks": ("HOOKS",)
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/create"
    FUNCTION = "create_hook_model_only"

    def create_hook_model_only(self, ckpt_name: str, strength_model: float,
                               prev_hooks: comfy.hooks.HookGroup=None):
        return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
#------------------------------------------
###########################################


###########################################
# Schedule Hooks
#------------------------------------------
class SetHookKeyframes:
    NodeId = 'SetHookKeyframes'
    NodeName = 'Set Hook Keyframes'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "hooks": ("HOOKS",),
            },
            "optional": {
                "hook_kf": ("HOOK_KEYFRAMES",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "set_hook_keyframes"

    def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
        if hook_kf is not None:
            hooks = hooks.clone()
            hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
        return (hooks,)

class CreateHookKeyframe:
    NodeId = 'CreateHookKeyframe'
    NodeName = 'Create Hook Keyframe'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "prev_hook_kf": ("HOOK_KEYFRAMES",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOK_KEYFRAMES",)
    RETURN_NAMES = ("HOOK_KF",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "create_hook_keyframe"

    def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
        if prev_hook_kf is None:
            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
        prev_hook_kf = prev_hook_kf.clone()
        keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
        prev_hook_kf.add(keyframe)
        return (prev_hook_kf,)

class CreateHookKeyframesInterpolated:
    NodeId = 'CreateHookKeyframesInterpolated'
    NodeName = 'Create Hook Keyframes Interp.'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
                "print_keyframes": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "prev_hook_kf": ("HOOK_KEYFRAMES",),
            },
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOK_KEYFRAMES",)
    RETURN_NAMES = ("HOOK_KF",)
    CATEGORY = "advanced/hooks/scheduling"
    FUNCTION = "create_hook_keyframes"

    def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
                              start_percent: float, end_percent: float, keyframes_count: int,
                              print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
        if prev_hook_kf is None:
            prev_hook_kf = comfy.hooks.HookKeyframeGroup()
        prev_hook_kf = prev_hook_kf.clone()
        percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
                                                               method=comfy.hooks.InterpolationMethod.LINEAR)
        strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)

        is_first = True
        for percent, strength in zip(percents, strengths):
            guarantee_steps = 0
            if is_first:
                guarantee_steps = 1
                is_first = False
            prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
            if print_keyframes:
                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
        return (prev_hook_kf,)

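# Note: assuming InterpolationMethod.get_weights(num_from, num_to, length,
# method) returns `length` values from num_from to num_to inclusive (the
# behaviour the node relies on), keyframes_count=5 with strength 1.0 -> 0.0 over
# start/end 0.0 -> 1.0 yields:
#   percents:  [0.0, 0.25, 0.5, 0.75, 1.0]
#   strengths: [1.0, 0.75, 0.5, 0.25, 0.0]
# The first keyframe gets guarantee_steps=1 so it applies on the very first
# sampling step even if its start_percent has already passed.
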
class CreateHookKeyframesFromFloats:
|
||||
NodeId = 'CreateHookKeyframesFromFloats'
|
||||
NodeName = 'Create Hook Keyframes From Floats'
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"print_keyframes": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||
}
|
||||
}
|
||||
|
||||
EXPERIMENTAL = True
|
||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||
RETURN_NAMES = ("HOOK_KF",)
|
||||
CATEGORY = "advanced/hooks/scheduling"
|
||||
FUNCTION = "create_hook_keyframes"
|
||||
|
||||
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
|
||||
start_percent: float, end_percent: float,
|
||||
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
|
||||
if prev_hook_kf is None:
|
||||
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
||||
prev_hook_kf = prev_hook_kf.clone()
|
||||
if type(floats_strength) in (float, int):
|
||||
floats_strength = [float(floats_strength)]
|
||||
elif isinstance(floats_strength, Iterable):
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
||||
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
||||
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||
|
||||
is_first = True
|
||||
for percent, strength in zip(percents, floats_strength):
|
||||
guarantee_steps = 0
|
||||
if is_first:
|
||||
guarantee_steps = 1
|
||||
is_first = False
|
||||
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
||||
if print_keyframes:
|
||||
print(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
||||
return (prev_hook_kf,)
#------------------------------------------
###########################################


class SetModelHooksOnCond:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "hooks": ("HOOKS",),
            },
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/manual"
    FUNCTION = "attach_hook"

    def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
        return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)


###########################################
# Combine Hooks
#------------------------------------------
class CombineHooks:
    NodeId = 'CombineHooks2'
    NodeName = 'Combine Hooks [2]'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)

class CombineHooksFour:
    NodeId = 'CombineHooks4'
    NodeName = 'Combine Hooks [4]'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
                "hooks_C": ("HOOKS",),
                "hooks_D": ("HOOKS",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)

class CombineHooksEight:
    NodeId = 'CombineHooks8'
    NodeName = 'Combine Hooks [8]'
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
                "hooks_C": ("HOOKS",),
                "hooks_D": ("HOOKS",),
                "hooks_E": ("HOOKS",),
                "hooks_F": ("HOOKS",),
                "hooks_G": ("HOOKS",),
                "hooks_H": ("HOOKS",),
            }
        }

    EXPERIMENTAL = True
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None,
                      hooks_E: comfy.hooks.HookGroup=None,
                      hooks_F: comfy.hooks.HookGroup=None,
                      hooks_G: comfy.hooks.HookGroup=None,
                      hooks_H: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
#------------------------------------------
###########################################

node_list = [
    # Create
    CreateHookLora,
    CreateHookLoraModelOnly,
    CreateHookModelAsLora,
    CreateHookModelAsLoraModelOnly,
    # Scheduling
    SetHookKeyframes,
    CreateHookKeyframe,
    CreateHookKeyframesInterpolated,
    CreateHookKeyframesFromFloats,
    # Combine
    CombineHooks,
    CombineHooksFour,
    CombineHooksEight,
    # Attach
    ConditioningSetProperties,
    ConditioningSetPropertiesAndCombine,
    PairConditioningSetProperties,
    PairConditioningSetPropertiesAndCombine,
    ConditioningSetDefaultAndCombine,
    PairConditioningSetDefaultAndCombine,
    PairConditioningCombine,
    SetClipHooks,
    # Other
    ConditioningTimestepsRange,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

for node in node_list:
    NODE_CLASS_MAPPINGS[node.NodeId] = node
    NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
@@ -15,9 +15,7 @@ class CLIPTextEncodeHunyuanDiT:
        tokens = clip.tokenize(bert)
        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]

        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        cond = output.pop("cond")
        return ([[cond, output]], )
        return (clip.encode_from_tokens_scheduled(tokens), )


NODE_CLASS_MAPPINGS = {
comfy_extras/nodes_lt.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils  # needed for common_upscale in LTXVImgToVideo below
import math

class EmptyLTXVLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        return ({"samples": latent}, )

class LTXVImgToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t
        return (positive, negative, {"samples": latent}, )


class LTXVConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"

    CATEGORY = "conditioning/video_models"

    def append(self, positive, negative, frame_rate):
        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
        return (positive, negative)


class ModelSamplingLTXV:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
                              "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
                              },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, max_shift, base_shift, latent=None):
        m = model.clone()

        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        shift = (tokens) * mm + b

        sampling_base = comfy.model_sampling.ModelSamplingFlux
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )
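The shift here is a linear ramp in the latent token count, anchored at (1024, base_shift) and (4096, max_shift); a standalone sketch of the same arithmetic:

def ltxv_shift(tokens, base_shift=0.95, max_shift=2.05, x1=1024, x2=4096):
    # linear interpolation/extrapolation of shift against token count
    mm = (max_shift - base_shift) / (x2 - x1)
    b = base_shift - mm * x1
    return tokens * mm + b

print(ltxv_shift(1024))  # 0.95 (base_shift)
print(ltxv_shift(4096))  # 2.05 (max_shift)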

class LTXVScheduler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
                     "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
                     "stretch": ("BOOLEAN", {
                         "default": True,
                         "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                     }),
                     "terminal": (
                         "FLOAT",
                         {
                             "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                             "tooltip": "The terminal value of the sigmas after stretching."
                         },
                     ),
                    },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return (sigmas,)
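The stretch step rescales the nonzero sigmas so the last one lands exactly on `terminal`; a plain-float sketch of that rescaling (the zero entry stays untouched, as in the node):

terminal = 0.1
sigmas = [0.9, 0.5, 0.2, 0.0]
non_zero = [s for s in sigmas if s != 0]
scale = (1.0 - non_zero[-1]) / (1.0 - terminal)
stretched = [1.0 - (1.0 - s) / scale for s in non_zero] + [0.0]
print(stretched)      # [0.8875, 0.4375, 0.1, 0.0]
print(stretched[-2])  # 0.1 == terminal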

NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
}
@@ -3,9 +3,6 @@ import torch
import comfy.model_management

class EmptyMochiLatentVideo:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
@@ -15,10 +12,10 @@ class EmptyMochiLatentVideo:
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/mochi"
    CATEGORY = "latent/video"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
        latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        return ({"samples":latent}, )

NODE_CLASS_MAPPINGS = {
@@ -26,8 +26,8 @@ class X0(comfy.model_sampling.EPS):
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
    original_timesteps = 50

    def __init__(self, model_config=None):
        super().__init__(model_config)
    def __init__(self, model_config=None, zsnr=None):
        super().__init__(model_config, zsnr=zsnr)

        self.skip_steps = self.num_timesteps // self.original_timesteps

@@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
        return log_sigma.exp().to(timestep.device)


def rescale_zero_terminal_snr_sigmas(sigmas):
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= (alphas_bar_sqrt_T)

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas_bar[-1] = 4.8973451890853435e-08
    return ((1 - alphas_bar) / alphas_bar) ** 0.5

class ModelSamplingDiscrete:
    @classmethod
    def INPUT_TYPES(s):
@@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        if zsnr:
            model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)

        m.add_object_patch("model_sampling", model_sampling)
        return (m, )
@@ -75,6 +75,34 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):

        return {"required": arg_dict}


class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["init_x_linear."] = argument
        arg_dict["positional_encoding"] = argument
        arg_dict["cond_seq_linear."] = argument
        arg_dict["register_tokens"] = argument
        arg_dict["t_embedder."] = argument

        for i in range(4):
            arg_dict["double_layers.{}.".format(i)] = argument

        for i in range(32):
            arg_dict["single_layers.{}.".format(i)] = argument

        arg_dict["modF."] = argument
        arg_dict["final_linear."] = argument

        return {"required": arg_dict}

class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

@@ -124,11 +152,58 @@ class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):

        return {"required": arg_dict}

class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["pos_frequencies."] = argument
        arg_dict["t_embedder."] = argument
        arg_dict["t5_y_embedder."] = argument
        arg_dict["t5_yproj."] = argument

        for i in range(48):
            arg_dict["blocks.{}.".format(i)] = argument

        arg_dict["final_layer."] = argument

        return {"required": arg_dict}

class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["patchify_proj."] = argument
        arg_dict["adaln_single."] = argument
        arg_dict["caption_projection."] = argument

        for i in range(28):
            arg_dict["transformer_blocks.{}.".format(i)] = argument

        arg_dict["scale_shift_table"] = argument
        arg_dict["proj_out."] = argument

        return {"required": arg_dict}

NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
    "ModelMergeSDXL": ModelMergeSDXL,
    "ModelMergeSD3_2B": ModelMergeSD3_2B,
    "ModelMergeAuraflow": ModelMergeAuraflow,
    "ModelMergeFlux1": ModelMergeFlux1,
    "ModelMergeSD35_Large": ModelMergeSD35_Large,
    "ModelMergeMochiPreview": ModelMergeMochiPreview,
    "ModelMergeLTXV": ModelMergeLTXV,
}

@@ -57,12 +57,24 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
    mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]

    total = mask.shape[-1]
    x = round(math.sqrt((lh / lw) * total))
    xx = None
    for i in range(0, math.floor(math.sqrt(total) / 2)):
        for j in [(x + i), max(1, x - i)]:
            if total % j == 0:
                xx = j
                break
        if xx is not None:
            break

    x = xx
    y = total // x

    # Reshape
    mask = (
        mask.reshape(b, *mid_shape)
        mask.reshape(b, x, y)
        .unsqueeze(1)
        .type(attn.dtype)
    )
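A standalone sketch of the divisor search above: it looks for an exact factorization total == x * y whose x/y aspect ratio best matches the latent's lh/lw (the node code assumes a divisor is always found; this sketch returns None otherwise):

import math

def closest_factorization(total, lh, lw):
    # start from the x that would give exactly the lh/lw aspect, then
    # widen outward until a divisor of total is found
    x = round(math.sqrt((lh / lw) * total))
    for i in range(0, math.floor(math.sqrt(total) / 2)):
        for j in [(x + i), max(1, x - i)]:
            if total % j == 0:
                return j, total // j
    return None

print(closest_factorization(96, 6, 4))  # (12, 8): 12/8 matches 6/4 exactly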

@@ -3,7 +3,9 @@ import comfy.sd
import comfy.model_management
import nodes
import torch
import re
import comfy_extras.nodes_slg


class TripleCLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -14,6 +16,8 @@ class TripleCLIPLoader:

    CATEGORY = "advanced/loaders"

    DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"

    def load_clip(self, clip_name1, clip_name2, clip_name3):
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
@@ -21,6 +25,7 @@ class TripleCLIPLoader:
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (clip,)


class EmptySD3LatentImage:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()
@@ -39,6 +44,7 @@ class EmptySD3LatentImage:
        latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
        return ({"samples":latent}, )


class CLIPTextEncodeSD3:
    @classmethod
    def INPUT_TYPES(s):
@@ -76,8 +82,7 @@ class CLIPTextEncodeSD3:
                tokens["l"] += empty["l"]
            while len(tokens["l"]) > len(tokens["g"]):
                tokens["g"] += empty["g"]
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return ([[cond, {"pooled_output": pooled}]], )
        return (clip.encode_from_tokens_scheduled(tokens), )


class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
@@ -95,7 +100,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
    CATEGORY = "conditioning/controlnet"
    DEPRECATED = True

class SkipLayerGuidanceSD3:

class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
    '''
    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
@@ -110,47 +116,12 @@ class SkipLayerGuidanceSD3:
                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "skip_guidance"
    FUNCTION = "skip_guidance_sd3"

    CATEGORY = "advanced/guidance"


    def skip_guidance(self, model, layers, scale, start_percent, end_percent):
        if layers == "" or layers == None:
            return (model, )
        # check if layer is comma separated integers
        def skip(args, extra_args):
            return args

        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        def post_cfg_function(args):
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            x = args["input"]
            model_options = args["model_options"].copy()

            for layer in layers:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
            model_sampling.percent_to_sigma(start_percent)

            sigma_ = sigma[0].item()
            if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
                (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
                cfg_result = cfg_result + (cond_pred - slg) * scale
            return cfg_result

        layers = re.findall(r'\d+', layers)
        layers = [int(i) for i in layers]
        m = model.clone()
        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m, )
    def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
        return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)


NODE_CLASS_MAPPINGS = {
comfy_extras/nodes_slg.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import comfy.model_patcher
import comfy.samplers
import re


class SkipLayerGuidanceDiT:
    '''
    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
    Original experimental implementation for SD3 by Dango233@StabilityAI.
    '''
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL", ),
                             "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
                             "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
                             "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                             "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "skip_guidance"
    EXPERIMENTAL = True

    DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."

    CATEGORY = "advanced/guidance"

    def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
        # check if layer is comma separated integers
        def skip(args, extra_args):
            return args

        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        double_layers = re.findall(r'\d+', double_layers)
        double_layers = [int(i) for i in double_layers]

        single_layers = re.findall(r'\d+', single_layers)
        single_layers = [int(i) for i in single_layers]

        if len(double_layers) == 0 and len(single_layers) == 0:
            return (model, )

        def post_cfg_function(args):
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            x = args["input"]
            model_options = args["model_options"].copy()

            for layer in double_layers:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)

            for layer in single_layers:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)

            model_sampling.percent_to_sigma(start_percent)

            sigma_ = sigma[0].item()
            if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
                (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
                cfg_result = cfg_result + (cond_pred - slg) * scale
                if rescaling_scale != 0:
                    factor = cond_pred.std() / cfg_result.std()
                    factor = rescaling_scale * factor + (1 - rescaling_scale)
                    cfg_result *= factor

            return cfg_result

        m = model.clone()
        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m, )


NODE_CLASS_MAPPINGS = {
    "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
}
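The layer strings are forgiving: the node only extracts runs of digits with a regex, so any comma/space separated list works. For example:

import re
layers = [int(i) for i in re.findall(r'\d+', "7, 8, 9")]
print(layers)  # [7, 8, 9]
# An empty string yields []; if both double_layers and single_layers
# come out empty, the node returns the model unchanged.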

Binary file not shown. (removed image; before: 116 KiB)
@@ -16,6 +16,7 @@ import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy.cli_args import args

class ExecutionResult(Enum):
@@ -480,7 +481,7 @@ class PromptExecutor:
            if self.caches.outputs.get(node_id) is not None:
                cached_nodes.append(node_id)

        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
        comfy.model_management.cleanup_models_gc()
        self.add_message("execution_cached",
                         { "nodes": cached_nodes, "prompt_id": prompt_id},
                         broadcast=False)
@@ -527,7 +528,6 @@ class PromptExecutor:
                comfy.model_management.unload_all_models()



def validate_inputs(prompt, item, validated):
    unique_id = item
    if unique_id in validated:
@@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated):
                r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
                received_type = r[val[1]]
                received_types[x] = received_type
                if 'input_types' not in validate_function_inputs and received_type != type_input:
                    details = f"{x}, {received_type} != {type_input}"
                if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
                    details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
                    error = {
                        "type": "return_type_mismatch",
                        "message": "Return type mismatch between linked nodes",

@@ -47,7 +47,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

        latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
        if x0.ndim == 5:
            x0 = x0[0, :, 0]
        else:
            x0 = x0[0]

        latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
        # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors

        return preview_to_image(latent_image)
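A minimal sketch of the preview projection above: a per-pixel linear map from latent channels to RGB, where the 5-dim branch simply picks the first frame of a video latent (shapes and the factor matrix here are illustrative; the real factors come from the model config):

import torch
C = 16                            # illustrative latent channel count
x0 = torch.randn(1, C, 2, 8, 8)   # (batch, C, frames, H, W) video latent
factors = torch.randn(3, C)       # hypothetical stand-in for latent_rgb_factors
x = x0[0, :, 0] if x0.ndim == 5 else x0[0]
rgb = torch.nn.functional.linear(x.movedim(0, -1), factors)
print(rgb.shape)                  # torch.Size([8, 8, 3])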

main.py (2 changes)
@@ -71,6 +71,7 @@ if os.name == "nt":
if __name__ == "__main__":
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.deterministic:
|
||||
@@ -153,7 +154,6 @@ def prompt_worker(q, server):
|
||||
if need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - last_gc_collect) > gc_collect_interval:
|
||||
comfy.model_management.cleanup_models()
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||
@@ -1,234 +0,0 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass


class DownloadStatusType(Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    ERROR = "error"


@dataclass
class DownloadModelStatus():
    status: str
    progress_percentage: float
    message: str
    already_existed: bool = False

    def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
        self.status = status.value  # Store the string value of the Enum
        self.progress_percentage = progress_percentage
        self.message = message
        self.already_existed = already_existed

    def to_dict(self) -> Dict[str, Any]:
        return {
            "status": self.status,
            "progress_percentage": self.progress_percentage,
            "message": self.message,
            "already_existed": self.already_existed
        }


async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
                         model_name: str,
                         model_url: str,
                         model_directory: str,
                         folder_path: str,
                         progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                         progress_interval: float = 1.0) -> DownloadModelStatus:
    """
    Download a model file from a given URL into the models directory.

    Args:
        model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
            A function that makes an HTTP request. This makes it easier to mock in unit tests.
        model_name (str):
            The name of the model file to be downloaded. This will be the filename on disk.
        model_url (str):
            The URL from which to download the model.
        model_directory (str):
            The subdirectory within the main models directory where the model
            should be saved (e.g., 'checkpoints', 'loras', etc.).
        progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
            An asynchronous function to call with progress updates.
        folder_path (str):
            Path of the model folder to use as the root.

    Returns:
        DownloadModelStatus: The result of the download operation.
    """
    if not validate_filename(model_name):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model name",
            False
        )

    if not model_directory in folder_names_and_paths:
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
            False
        )

    if not folder_path in get_folder_paths(model_directory):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
            False
        )

    file_path = create_model_path(model_name, folder_path)
    existing_file = await check_file_exists(file_path, model_name, progress_callback)
    if existing_file:
        return existing_file

    try:
        logging.info(f"Downloading {model_name} from {model_url}")
        status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
        await progress_callback(model_name, status)

        response = await model_download_request(model_url)
        if response.status != 200:
            error_message = f"Failed to download {model_name}. Status code: {response.status}"
            logging.error(error_message)
            status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
            await progress_callback(model_name, status)
            return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)

        return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)

    except Exception as e:
        logging.error(f"Error in downloading model: {e}")
        return await handle_download_error(e, model_name, progress_callback)


def create_model_path(model_name: str, folder_path: str) -> str:
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, model_name)

    # Ensure the resulting path is still within the base directory
    abs_file_path = os.path.abspath(file_path)
    abs_base_dir = os.path.abspath(folder_path)
    if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
        raise Exception(f"Invalid model directory: {folder_path}/{model_name}")

    return file_path


async def check_file_exists(file_path: str,
                            model_name: str,
                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
                            ) -> Optional[DownloadModelStatus]:
    if os.path.exists(file_path):
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
        await progress_callback(model_name, status)
        return status
    return None


async def track_download_progress(response: aiohttp.ClientResponse,
                                  file_path: str,
                                  model_name: str,
                                  progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                  interval: float = 1.0) -> DownloadModelStatus:
    try:
        total_size = int(response.headers.get('Content-Length', 0))
        downloaded = 0
        last_update_time = time.time()

        async def update_progress():
            nonlocal last_update_time
            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
            status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
            await progress_callback(model_name, status)
            last_update_time = time.time()

        temp_file_path = file_path + '.tmp'
        with open(temp_file_path, 'wb') as f:
            chunk_iterator = response.content.iter_chunked(8192)
            while True:
                try:
                    chunk = await chunk_iterator.__anext__()
                except StopAsyncIteration:
                    break
                f.write(chunk)
                downloaded += len(chunk)

                if time.time() - last_update_time >= interval:
                    await update_progress()

        os.rename(temp_file_path, file_path)

        await update_progress()

        logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
        await progress_callback(model_name, status)

        return status
    except Exception as e:
        logging.error(f"Error in track_download_progress: {e}")
        logging.error(traceback.format_exc())
        return await handle_download_error(e, model_name, progress_callback)


async def handle_download_error(e: Exception,
                                model_name: str,
                                progress_callback: Callable[[str, DownloadModelStatus], Any]
                                ) -> DownloadModelStatus:
    error_message = f"Error downloading {model_name}: {str(e)}"
    status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
    await progress_callback(model_name, status)
    return status


def validate_filename(filename: str) -> bool:
    """
    Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

    Args:
        filename (str): The filename to validate

    Returns:
        bool: True if the filename is valid, False otherwise
    """
    if not filename.lower().endswith(('.sft', '.safetensors')):
        return False

    # Check if the filename is empty, None, or just whitespace
    if not filename or not filename.strip():
        return False

    # Check for any directory traversal attempts or invalid characters
    if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
        return False

    # Check if the filename starts with a dot (hidden file)
    if filename.startswith('.'):
        return False

    # Use a whitelist of allowed characters
    if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
        return False

    # Ensure the filename isn't too long
    if len(filename) > 255:
        return False

    return True
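A few example verdicts from the checks above (a sketch, assuming the function is imported as-is):

print(validate_filename("model.safetensors"))     # True
print(validate_filename("model.ckpt"))            # False (extension whitelist)
print(validate_filename("../model.safetensors"))  # False (path traversal)
print(validate_filename(".hidden.safetensors"))   # False (hidden file)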

nodes.py (64 changes)
@@ -1,3 +1,4 @@
from __future__ import annotations
import torch

import os
@@ -24,6 +25,7 @@ import comfy.sample
import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict

import comfy.clip_vision

@@ -44,16 +46,16 @@ def interrupt_processing(value=True):

MAX_RESOLUTION=16384

class CLIPTextEncode:
class CLIPTextEncode(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(s):
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
                "clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
                "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
            }
        }
    RETURN_TYPES = ("CONDITIONING",)
    RETURN_TYPES = (IO.CONDITIONING,)
    OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
    FUNCTION = "encode"

@@ -62,9 +64,8 @@ class CLIPTextEncode:

    def encode(self, clip, text):
        tokens = clip.tokenize(text)
        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        cond = output.pop("cond")
        return ([[cond, output]], )
        return (clip.encode_from_tokens_scheduled(tokens), )


class ConditioningCombine:
    @classmethod
@@ -290,15 +291,22 @@ class VAEDecodeTiled:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
                             "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
                             "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing"

    def decode(self, vae, samples, tile_size):
        return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
    def decode(self, vae, samples, tile_size, overlap=64):
        if tile_size < overlap * 4:
            overlap = tile_size // 4
        compression = vae.spacial_compression_decode()
        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
        if len(images.shape) == 5: #Combine batches
            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
        return (images, )

class VAEEncode:
    @classmethod
@@ -376,6 +384,7 @@ class InpaintModelConditioning:
                             "vae": ("VAE", ),
                             "pixels": ("IMAGE", ),
                             "mask": ("MASK", ),
                             "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
                             }}

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
@@ -384,7 +393,7 @@ class InpaintModelConditioning:

    CATEGORY = "conditioning/inpaint"

    def encode(self, positive, negative, pixels, vae, mask):
    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -408,7 +417,8 @@ class InpaintModelConditioning:
        out_latent = {}

        out_latent["samples"] = orig_latent
        out_latent["noise_mask"] = mask
        if noise_mask:
            out_latent["noise_mask"] = mask

        out = []
        for conditioning in [positive, negative]:
@@ -889,13 +899,15 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"

    def load_clip(self, clip_name, type="stable_diffusion"):
        if type == "stable_cascade":
            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
@@ -905,6 +917,8 @@ class CLIPLoader:
            clip_type = comfy.sd.CLIPType.STABLE_AUDIO
        elif type == "mochi":
            clip_type = comfy.sd.CLIPType.MOCHI
        elif type == "ltxv":
            clip_type = comfy.sd.CLIPType.LTXV
        else:
            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

@@ -924,6 +938,8 @@ class DualCLIPLoader:

    CATEGORY = "advanced/loaders"

    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"

    def load_clip(self, clip_name1, clip_name2, type):
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
@@ -956,15 +972,19 @@ class CLIPVisionEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_vision": ("CLIP_VISION",),
                              "image": ("IMAGE",)
                              "image": ("IMAGE",),
                              "crop": (["center", "none"],)
                              }}
    RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip_vision, image):
        output = clip_vision.encode_image(image)
    def encode(self, clip_vision, image, crop):
        crop_image = True
        if crop != "center":
            crop_image = False
        output = clip_vision.encode_image(image, crop=crop_image)
        return (output,)

class StyleModelLoader:
@@ -989,14 +1009,19 @@ class StyleModelApply:
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
                             "strength_type": (["multiply"], ),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"

    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        if strength_type == "multiply":
            cond *= strength

        c = []
        for t in conditioning:
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
@@ -2123,6 +2148,9 @@ def init_builtin_extra_nodes():
        "nodes_lora_extract.py",
        "nodes_torch_compile.py",
        "nodes_mochi.py",
        "nodes_slg.py",
        "nodes_lt.py",
        "nodes_hooks.py",
    ]

    import_failed = []
server.py (33 changes)
@@ -29,7 +29,6 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes

@@ -152,7 +151,7 @@ class PromptServer():
            mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

        self.user_manager = UserManager()
        self.internal_routes = InternalRoutes()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = None
        self.loop = loop
@@ -676,36 +675,6 @@ class PromptServer():
                self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

        # Internal route. Should not be depended upon and is subject to change at any time.
        # TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
        # NOTE: This was an experiment and WILL BE REMOVED
        @routes.post("/internal/models/download")
        async def download_handler(request):
            async def report_progress(filename: str, status: DownloadModelStatus):
                payload = status.to_dict()
                payload['download_path'] = filename
                await self.send_json("download_progress", payload)

            data = await request.json()
            url = data.get('url')
            model_directory = data.get('model_directory')
            folder_path = data.get('folder_path')
            model_filename = data.get('model_filename')
            progress_interval = data.get('progress_interval', 1.0)  # In seconds, how often to report download progress.

            if not url or not model_directory or not model_filename or not folder_path:
                return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)

            session = self.client_session
            if session is None:
                logging.error("Client session is not initialized")
                return web.Response(status=500)

            task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
            await task

            return web.json_response(task.result().to_dict())

    async def setup(self):
        timeout = aiohttp.ClientTimeout(total=None)  # no timeout
tests-unit/execution_test/validate_node_input_test.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import pytest
from comfy_execution.validation import validate_node_input


def test_exact_match():
    """Test cases where types match exactly"""
    assert validate_node_input("STRING", "STRING")
    assert validate_node_input("STRING,INT", "STRING,INT")
    assert validate_node_input("INT,STRING", "STRING,INT")  # Order shouldn't matter


def test_strict_mode():
    """Test strict mode validation"""
    # Should pass - received type is subset of input type
    assert validate_node_input("STRING", "STRING,INT", strict=True)
    assert validate_node_input("INT", "STRING,INT", strict=True)
    assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)

    # Should fail - received type is not subset of input type
    assert not validate_node_input("STRING,INT", "STRING", strict=True)
    assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
    assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)


def test_non_strict_mode():
    """Test non-strict mode validation (default behavior)"""
    # Should pass - types have overlap
    assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
    assert validate_node_input("STRING,INT", "INT,BOOLEAN")
    assert validate_node_input("STRING", "STRING,INT")

    # Should fail - no overlap in types
    assert not validate_node_input("BOOLEAN", "STRING,INT")
    assert not validate_node_input("FLOAT", "STRING,INT")
    assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")


def test_whitespace_handling():
    """Test that whitespace is handled correctly"""
    assert validate_node_input("STRING, INT", "STRING,INT")
    assert validate_node_input("STRING,INT", "STRING, INT")
    assert validate_node_input(" STRING , INT ", "STRING,INT")
    assert validate_node_input("STRING,INT", " STRING , INT ")


def test_empty_strings():
    """Test behavior with empty strings"""
    assert validate_node_input("", "")
    assert not validate_node_input("STRING", "")
    assert not validate_node_input("", "STRING")


def test_single_vs_multiple():
    """Test single type against multiple types"""
    assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
    assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
    assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)


def test_non_string():
    """Test non-string types"""
    obj1 = object()
    obj2 = object()
    assert validate_node_input(obj1, obj1)
    assert not validate_node_input(obj1, obj2)


class NotEqualsOverrideTest(str):
    """Test class for ``__ne__`` override."""

    def __ne__(self, value: object) -> bool:
        if self == "*" or value == "*":
            return False
        if self == "LONGER_THAN_2":
            return not len(value) > 2
        raise TypeError("This is a class for unit tests only.")


def test_ne_override():
    """Test ``__ne__`` any override"""
    any = NotEqualsOverrideTest("*")
    invalid_type = "INVALID_TYPE"
    obj = object()
    assert validate_node_input(any, any)
    assert validate_node_input(any, invalid_type)
    assert validate_node_input(any, obj)
    assert validate_node_input(any, {})
    assert validate_node_input(any, [])
    assert validate_node_input(any, [1, 2, 3])


def test_ne_custom_override():
    """Test ``__ne__`` custom override"""
    special = NotEqualsOverrideTest("LONGER_THAN_2")

    assert validate_node_input(special, special)
    assert validate_node_input(special, "*")
    assert validate_node_input(special, "INVALID_TYPE")
    assert validate_node_input(special, [1, 2, 3])

    # Should fail
    assert not validate_node_input(special, [1, 2])
    assert not validate_node_input(special, "TY")


@pytest.mark.parametrize(
    "received,input_type,strict,expected",
    [
        ("STRING", "STRING", False, True),
        ("STRING,INT", "STRING,INT", False, True),
        ("STRING", "STRING,INT", True, True),
        ("STRING,INT", "STRING", True, False),
        ("BOOLEAN", "STRING,INT", False, False),
        ("STRING,BOOLEAN", "STRING,INT", False, True),
    ],
)
def test_parametrized_cases(received, input_type, strict, expected):
    """Parametrized test cases for various scenarios"""
    assert validate_node_input(received, input_type, strict) == expected
@@ -1,337 +0,0 @@
import pytest
import tempfile
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths

@pytest.fixture
def temp_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield tmpdirname

class AsyncIteratorMock:
    """
    A mock class that simulates an asynchronous iterator.
    This is used to mimic the behavior of aiohttp's content iterator.
    """
    def __init__(self, seq):
        # Convert the input sequence into an iterator
        self.iter = iter(seq)

    def __aiter__(self):
        # This method is called when 'async for' is used
        return self

    async def __anext__(self):
        # This method is called for each iteration in an 'async for' loop
        try:
            return next(self.iter)
        except StopIteration:
            # This is the asynchronous equivalent of StopIteration
            raise StopAsyncIteration

class ContentMock:
    """
    A mock class that simulates the content attribute of an aiohttp ClientResponse.
    This class provides the iter_chunked method which returns an async iterator of chunks.
    """
    def __init__(self, chunks):
        # Store the chunks that will be returned by the iterator
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        # This method mimics aiohttp's content.iter_chunked()
        # For simplicity in testing, we ignore chunk_size and just return our predefined chunks
        return AsyncIteratorMock(self.chunks)

@pytest.mark.asyncio
async def test_download_model_success(temp_dir):
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.status = 200
    mock_response.headers = {'Content-Length': '1000'}
    # Create a mock for content that returns an async iterator directly
    chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
    mock_response.content = ContentMock(chunks)

    mock_make_request = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    time_values = itertools.count(0, 0.1)

    fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}

    with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
            patch('model_filemanager.check_file_exists', return_value=None), \
            patch('folder_paths.folder_names_and_paths', fake_paths), \
            patch('time.time', side_effect=time_values):  # Simulate time passing

        result = await download_model(
            mock_make_request,
            'model.sft',
            'http://example.com/model.sft',
            'checkpoints',
            temp_dir,
            mock_progress_callback
        )

        # Assert the result
        assert isinstance(result, DownloadModelStatus)
        assert result.message == 'Successfully downloaded model.sft'
        assert result.status == 'completed'
        assert result.already_existed is False

        # Check progress callback calls
        assert mock_progress_callback.call_count >= 3  # At least start, one progress update, and completion

        # Check initial call
        mock_progress_callback.assert_any_call(
            'model.sft',
            DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
        )

        # Check final call
        mock_progress_callback.assert_any_call(
            'model.sft',
            DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
        )

        mock_file_path = os.path.join(temp_dir, 'model.sft')
        assert os.path.exists(mock_file_path)
        with open(mock_file_path, 'rb') as mock_file:
            assert mock_file.read() == b''.join(chunks)
        os.remove(mock_file_path)

        # Verify request was made
        mock_make_request.assert_called_once_with('http://example.com/model.sft')

@pytest.mark.asyncio
async def test_download_model_url_request_failure(temp_dir):
    # Mock dependencies
    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404  # Simulate a "Not Found" error
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}

    # Mock the create_model_path function
    with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
            patch('model_filemanager.check_file_exists', return_value=None), \
            patch('folder_paths.folder_names_and_paths', fake_paths):
        # Call the function
        result = await download_model(
            mock_get,
            'model.safetensors',
            'http://example.com/model.safetensors',
            'checkpoints',
            temp_dir,
            mock_progress_callback
        )

        # Assert the expected behavior
        assert isinstance(result, DownloadModelStatus)
        assert result.status == 'error'
        assert result.message == 'Failed to download model.safetensors. Status code: 404'
        assert result.already_existed is False

        # Check that progress_callback was called with the correct arguments
        mock_progress_callback.assert_any_call(
            'model.safetensors',
            DownloadModelStatus(
                status=DownloadStatusType.PENDING,
                progress_percentage=0,
                message='Starting download of model.safetensors',
                already_existed=False
            )
        )
        mock_progress_callback.assert_called_with(
            'model.safetensors',
            DownloadModelStatus(
                status=DownloadStatusType.ERROR,
                progress_percentage=0,
                message='Failed to download model.safetensors. Status code: 404',
                already_existed=False
            )
        )

        # Verify that the get method was called with the correct URL
        mock_get.assert_called_once_with('http://example.com/model.safetensors')

@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
        mock_make_request,
        'model.sft',
        'http://example.com/model.sft',
        '../bad_path',
        '../bad_path',
        mock_progress_callback
    )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message.startswith('Invalid or unrecognized model directory')
    assert result.status == 'error'
    assert result.already_existed is False

@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
        mock_make_request,
        'model.sft',
        'http://example.com/model.sft',
        'checkpoints',
        'invalid_path',
        mock_progress_callback
    )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message.startswith("Invalid folder path")
    assert result.status == 'error'
    assert result.already_existed is False

def test_create_model_path(tmp_path, monkeypatch):
    model_name = "model.safetensors"
    folder_path = os.path.join(tmp_path, "mock_dir")

    file_path = create_model_path(model_name, folder_path)

    assert file_path == os.path.join(folder_path, "model.safetensors")
    assert os.path.exists(os.path.dirname(file_path))

    with pytest.raises(Exception, match="Invalid model directory"):
        create_model_path("../path_traversal.safetensors", folder_path)

    with pytest.raises(Exception, match="Invalid model directory"):
        create_model_path("/etc/some_root_path", folder_path)


@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    file_path = tmp_path / "existing_model.sft"
    file_path.touch()  # Create an empty file

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)

    assert result is not None
    assert result.status == "completed"
    assert result.message == "existing_model.sft already exists"
    assert result.already_existed is True

    mock_callback.assert_called_once_with(
        "existing_model.sft",
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
    )

@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    file_path = tmp_path / "non_existing_model.sft"

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)

    assert result is None
    mock_callback.assert_not_called()

@pytest.mark.asyncio
async def test_track_download_progress_no_content_length(temp_dir):
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {}  # No Content-Length header
    chunks = [b'a' * 500, b'b' * 500]
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

    mock_callback = AsyncMock()

    full_path = os.path.join(temp_dir, 'model.sft')

    result = await track_download_progress(
        mock_response, full_path, 'model.sft',
        mock_callback, interval=0.1
    )

    assert result.status == "completed"

    assert os.path.exists(full_path)
    with open(full_path, 'rb') as f:
        assert f.read() == b''.join(chunks)
    os.remove(full_path)

    # Check that progress was reported even without knowing the total size
    mock_callback.assert_any_call(
        'model.sft',
        DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
    )

@pytest.mark.asyncio
async def test_track_download_progress_interval(temp_dir):
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    chunks = [b'a' * 100] * 10
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Create a mock time function that returns incremental float values
    mock_time = MagicMock()
    mock_time.side_effect = [i * 0.5 for i in range(30)]  # This should be enough for 10 chunks

    full_path = os.path.join(temp_dir, 'model.sft')

    with patch('time.time', mock_time):
        await track_download_progress(
            mock_response, full_path, 'model.sft',
            mock_callback, interval=1.0
        )

    assert os.path.exists(full_path)
    with open(full_path, 'rb') as f:
        assert f.read() == b''.join(chunks)
    os.remove(full_path)

    # Assert that progress was updated at least 3 times (start, at least one interval, and end)
    assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"

    # Verify the first and last calls
    first_call = mock_callback.call_args_list[0]
    assert first_call[0][1].status == "in_progress"
    # Allow for some initial progress, but it should be less than 50%
    assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"

    last_call = mock_callback.call_args_list[-1]
    assert last_call[0][1].status == "completed"
    assert last_call[0][1].progress_percentage == 100

@pytest.mark.parametrize("filename, expected", [
    ("valid_model.safetensors", True),
    ("valid_model.sft", True),
    ("valid model.safetensors", True),  # Test with space
    ("UPPERCASE_MODEL.SAFETENSORS", True),
    ("model_with.multiple.dots.pt", False),
    ("", False),  # Empty string
    ("../../../etc/passwd", False),  # Path traversal attempt
    ("/etc/passwd", False),  # Absolute path
    ("\\windows\\system32\\config\\sam", False),  # Windows path
    (".hidden_file.pt", False),  # Hidden file
    ("invalid<char>.ckpt", False),  # Invalid character
    ("invalid?.ckpt", False),  # Another invalid character
    ("very" * 100 + ".safetensors", False),  # Too long filename
    ("\nmodel_with_newline.pt", False),  # Newline character
    ("model_with_emoji😊.pt", False),  # Emoji in filename
])
def test_validate_filename(filename, expected):
    assert validate_filename(filename) == expected
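One detail of the removed file worth keeping: AsyncIteratorMock is a general pattern for feeding `async for` loops from plain sequences. A self-contained usage sketch (the chunk values are illustrative):

```python
import asyncio

class AsyncIteratorMock:
    """Adapts any ordinary iterable to the async-iterator protocol."""

    def __init__(self, seq):
        self.iter = iter(seq)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self.iter)
        except StopIteration:
            # Async iteration signals exhaustion with StopAsyncIteration.
            raise StopAsyncIteration

async def read_all(chunks):
    # Same consumption pattern track_download_progress used on
    # response.content.iter_chunked(...).
    buffer = bytearray()
    async for chunk in AsyncIteratorMock(chunks):
        buffer.extend(chunk)
    return bytes(buffer)

assert asyncio.run(read_all([b'a' * 500, b'b' * 500])) == b'a' * 500 + b'b' * 500
```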
@@ -14,7 +14,7 @@ def user_manager(tmp_path):
    um = UserManager()
    um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
        tmp_path, file
-   )
+   ) if file else tmp_path
    return um


@@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
    assert resp.status == 200
-   assert await resp.json() == [
-       ["subdir/file1.txt", "subdir", "file1.txt"]
-   ]
+   assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]


async def test_listuserdata_invalid_directory(aiohttp_client, app):
@@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
    assert "/" in result[0]["path"]  # Ensure forward slash is used
    assert "\\" not in result[0]["path"]  # Ensure backslash is not present
    assert result[0]["path"] == "subdir/file1.txt"


async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    content = b"test content"
    resp = await client.post("/userdata/test.txt", data=content)

    assert resp.status == 200
    assert await resp.text() == '"test.txt"'

    # Verify file was created with correct content
    with open(tmp_path / "test.txt", "rb") as f:
        assert f.read() == content


async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "test.txt", "w") as f:
        f.write("initial content")

    client = await aiohttp_client(app)
    new_content = b"updated content"
    resp = await client.post("/userdata/test.txt", data=new_content)

    assert resp.status == 200
    assert await resp.text() == '"test.txt"'

    # Verify file was overwritten
    with open(tmp_path / "test.txt", "rb") as f:
        assert f.read() == new_content


async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "test.txt", "w") as f:
        f.write("initial content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")

    assert resp.status == 409

    # Verify original content unchanged
    with open(tmp_path / "test.txt", "r") as f:
        assert f.read() == "initial content"


async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    content = b"test content"
    resp = await client.post("/userdata/test.txt?full_info=true", data=content)

    assert resp.status == 200
    result = await resp.json()
    assert result["path"] == "test.txt"
    assert result["size"] == len(content)
    assert "modified" in result


async def test_move_userdata(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "source.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt")

    assert resp.status == 200
    assert await resp.text() == '"dest.txt"'

    # Verify file was moved
    assert not os.path.exists(tmp_path / "source.txt")
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "test content"


async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    # Create source and destination files
    with open(tmp_path / "source.txt", "w") as f:
        f.write("source content")
    with open(tmp_path / "dest.txt", "w") as f:
        f.write("destination content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")

    assert resp.status == 409

    # Verify files remain unchanged
    with open(tmp_path / "source.txt", "r") as f:
        assert f.read() == "source content"
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "destination content"


async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "source.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")

    assert resp.status == 200
    result = await resp.json()
    assert result["path"] == "dest.txt"
    assert result["size"] == len("test content")
    assert "modified" in result

    # Verify file was moved
    assert not os.path.exists(tmp_path / "source.txt")
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "test content"
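The new tests above double as documentation for the `/userdata` endpoints. A client-side sketch of the same semantics (server address, file names, and the exact JSON shape are assumptions read off the assertions, not a documented API):

```python
import asyncio
import aiohttp

async def demo():
    async with aiohttp.ClientSession(base_url="http://127.0.0.1:8188") as session:
        # Create a user file; the response body is the JSON-quoted path.
        resp = await session.post("/userdata/notes.txt", data=b"hello")
        assert resp.status == 200 and await resp.text() == '"notes.txt"'

        # overwrite=false refuses to clobber an existing file (HTTP 409).
        resp = await session.post("/userdata/notes.txt?overwrite=false", data=b"bye")
        assert resp.status == 409

        # full_info=true returns an object with path, size and a
        # "modified" timestamp instead of the bare path string.
        resp = await session.post("/userdata/info.txt?full_info=true", data=b"hello")
        info = await resp.json()
        assert info["path"] == "info.txt" and info["size"] == 5

        # Moves honor the same overwrite / full_info query flags.
        resp = await session.post("/userdata/notes.txt/move/archive.txt")
        assert await resp.text() == '"archive.txt"'

asyncio.run(demo())
```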
@@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory

@pytest.fixture
def internal_routes():
-   return InternalRoutes()
+   return InternalRoutes(None)

@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
@@ -102,7 +102,7 @@ async def test_file_service_initialization():
        # Create a mock instance
        mock_file_service_instance = MagicMock(spec=FileService)
        MockFileService.return_value = mock_file_service_instance
-       internal_routes = InternalRoutes()
+       internal_routes = InternalRoutes(None)

        # Check if FileService was initialized with the correct parameters
        MockFileService.assert_called_once_with({
@@ -112,4 +112,4 @@ async def test_file_service_initialization():
        })

        # Verify that the file_service attribute of InternalRoutes is set
        assert internal_routes.file_service == mock_file_service_instance
@@ -24,5 +24,8 @@ def load_extra_path_config(yaml_path):
                full_path = y
                if base_path is not None:
                    full_path = os.path.join(base_path, full_path)
+               elif not os.path.isabs(full_path):
+                   yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
+                   full_path = os.path.abspath(os.path.join(yaml_dir, y))
                logging.info("Adding extra search path {} {}".format(x, full_path))
                folder_paths.add_model_folder_path(x, full_path, is_default)
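The three added lines change where relative `extra_model_paths.yaml` entries land: unless a `base_path` is given, they now resolve against the YAML file's own directory instead of the process working directory. A standalone sketch of just that resolution rule (file locations are hypothetical):

```python
import os

def resolve_extra_path(y, base_path, yaml_path):
    # Mirrors the patched branch in load_extra_path_config.
    full_path = y
    if base_path is not None:
        full_path = os.path.join(base_path, full_path)
    elif not os.path.isabs(full_path):
        yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
        full_path = os.path.abspath(os.path.join(yaml_dir, y))
    return full_path

# A relative entry now resolves next to the YAML file, not the CWD:
print(resolve_extra_path("models/checkpoints", None, "/opt/comfy/extra_model_paths.yaml"))
# /opt/comfy/models/checkpoints

# Absolute entries and explicit base_path behave as before:
print(resolve_extra_path("/srv/shared/loras", None, "/opt/comfy/extra_model_paths.yaml"))
# /srv/shared/loras
```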
103
web/assets/ExtensionPanel-BmKi_NKS.js
generated
vendored
103
web/assets/ExtensionPanel-BmKi_NKS.js
generated
vendored
@@ -1,103 +0,0 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, bQ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bR as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a6 as toDisplayString, aw as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-BHayQCxv.js";
|
||||
import { s as script, a as script$2, b as script$3 } from "./index-CwRXxFdA.js";
|
||||
import "./index-C_wOqB0f.js";
|
||||
const _hoisted_1 = { class: "extension-panel" };
|
||||
const _hoisted_2 = { class: "mt-4" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ExtensionPanel",
|
||||
setup(__props) {
|
||||
const extensionStore = useExtensionStore();
|
||||
const settingStore = useSettingStore();
|
||||
const editingEnabledExtensions = ref({});
|
||||
onMounted(() => {
|
||||
extensionStore.extensions.forEach((ext) => {
|
||||
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
|
||||
});
|
||||
});
|
||||
const changedExtensions = computed(() => {
|
||||
return extensionStore.extensions.filter(
|
||||
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
|
||||
);
|
||||
});
|
||||
const hasChanges = computed(() => {
|
||||
return changedExtensions.value.length > 0;
|
||||
});
|
||||
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
|
||||
const editingDisabledExtensionNames = Object.entries(
|
||||
editingEnabledExtensions.value
|
||||
).filter(([_, enabled]) => !enabled).map(([name]) => name);
|
||||
settingStore.set("Comfy.Extension.Disabled", [
|
||||
...extensionStore.inactiveDisabledExtensionNames,
|
||||
...editingDisabledExtensionNames
|
||||
]);
|
||||
}, "updateExtensionStatus");
|
||||
const applyChanges = /* @__PURE__ */ __name(() => {
|
||||
window.location.reload();
|
||||
}, "applyChanges");
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createVNode(unref(script$2), {
|
||||
value: unref(extensionStore).extensions,
|
||||
stripedRows: "",
|
||||
size: "small"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script), {
|
||||
field: "name",
|
||||
header: _ctx.$t("extensionName"),
|
||||
sortable: ""
|
||||
}, null, 8, ["header"]),
|
||||
createVNode(unref(script), { pt: {
|
||||
bodyCell: "flex items-center justify-end"
|
||||
} }, {
|
||||
body: withCtx((slotProps) => [
|
||||
createVNode(unref(script$1), {
|
||||
modelValue: editingEnabledExtensions.value[slotProps.data.name],
|
||||
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
|
||||
onChange: updateExtensionStatus
|
||||
}, null, 8, ["modelValue", "onUpdate:modelValue"])
|
||||
]),
|
||||
_: 1
|
||||
})
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["value"]),
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
hasChanges.value ? (openBlock(), createBlock(unref(script$3), {
|
||||
key: 0,
|
||||
severity: "info"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("ul", null, [
|
||||
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
|
||||
return openBlock(), createElementBlock("li", {
|
||||
key: ext.name
|
||||
}, [
|
||||
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
|
||||
createTextVNode(" " + toDisplayString(ext.name), 1)
|
||||
]);
|
||||
}), 128))
|
||||
])
|
||||
]),
|
||||
_: 1
|
||||
})) : createCommentVNode("", true),
|
||||
createVNode(unref(script$4), {
|
||||
label: _ctx.$t("reloadToApplyChanges"),
|
||||
icon: "pi pi-refresh",
|
||||
onClick: applyChanges,
|
||||
disabled: !hasChanges.value,
|
||||
text: "",
|
||||
fluid: "",
|
||||
severity: "danger"
|
||||
}, null, 8, ["label", "disabled"])
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ExtensionPanel-BmKi_NKS.js.map
|
||||
1
web/assets/ExtensionPanel-BmKi_NKS.js.map
generated
vendored
1
web/assets/ExtensionPanel-BmKi_NKS.js.map
generated
vendored
@@ -1 +0,0 @@
|
||||
{"version":3,"file":"ExtensionPanel-BmKi_NKS.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
117
web/assets/ExtensionPanel-DsD42OtO.js
generated
vendored
Normal file
117
web/assets/ExtensionPanel-DsD42OtO.js
generated
vendored
Normal file
@@ -0,0 +1,117 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, r as ref, c6 as FilterMatchMode, ca as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, g as openBlock, x as createBlock, y as withCtx, i as createVNode, c7 as SearchBox, z as unref, bT as script, A as createBaseVNode, h as createElementBlock, O as renderList, a6 as toDisplayString, aw as createTextVNode, N as Fragment, D as script$1, j as createCommentVNode, bV as script$3, c8 as _sfc_main$1 } from "./index-CoOvI8ZH.js";
|
||||
import { s as script$2, a as script$4 } from "./index-DK6Kev7f.js";
|
||||
import "./index-D4DWQPPQ.js";
|
||||
const _hoisted_1 = { class: "flex justify-end" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ExtensionPanel",
|
||||
setup(__props) {
|
||||
const filters = ref({
|
||||
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
|
||||
});
|
||||
const extensionStore = useExtensionStore();
|
||||
const settingStore = useSettingStore();
|
||||
const editingEnabledExtensions = ref({});
|
||||
onMounted(() => {
|
||||
extensionStore.extensions.forEach((ext) => {
|
||||
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
|
||||
});
|
||||
});
|
||||
const changedExtensions = computed(() => {
|
||||
return extensionStore.extensions.filter(
|
||||
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
|
||||
);
|
||||
});
|
||||
const hasChanges = computed(() => {
|
||||
return changedExtensions.value.length > 0;
|
||||
});
|
||||
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
|
||||
const editingDisabledExtensionNames = Object.entries(
|
||||
editingEnabledExtensions.value
|
||||
).filter(([_, enabled]) => !enabled).map(([name]) => name);
|
||||
settingStore.set("Comfy.Extension.Disabled", [
|
||||
...extensionStore.inactiveDisabledExtensionNames,
|
||||
...editingDisabledExtensionNames
|
||||
]);
|
||||
}, "updateExtensionStatus");
|
||||
const applyChanges = /* @__PURE__ */ __name(() => {
|
||||
window.location.reload();
|
||||
}, "applyChanges");
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createBlock(_sfc_main$1, {
|
||||
value: "Extension",
|
||||
class: "extension-panel"
|
||||
}, {
|
||||
header: withCtx(() => [
|
||||
createVNode(SearchBox, {
|
||||
modelValue: filters.value["global"].value,
|
||||
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
|
||||
placeholder: _ctx.$t("searchExtensions") + "..."
|
||||
}, null, 8, ["modelValue", "placeholder"]),
|
||||
hasChanges.value ? (openBlock(), createBlock(unref(script), {
|
||||
key: 0,
|
||||
severity: "info",
|
||||
"pt:text": "w-full"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("ul", null, [
|
||||
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
|
||||
return openBlock(), createElementBlock("li", {
|
||||
key: ext.name
|
||||
}, [
|
||||
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
|
||||
createTextVNode(" " + toDisplayString(ext.name), 1)
|
||||
]);
|
||||
}), 128))
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_1, [
|
||||
createVNode(unref(script$1), {
|
||||
label: _ctx.$t("reloadToApplyChanges"),
|
||||
onClick: applyChanges,
|
||||
outlined: "",
|
||||
severity: "danger"
|
||||
}, null, 8, ["label"])
|
||||
])
|
||||
]),
|
||||
_: 1
|
||||
})) : createCommentVNode("", true)
|
||||
]),
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script$4), {
|
||||
value: unref(extensionStore).extensions,
|
||||
stripedRows: "",
|
||||
size: "small",
|
||||
filters: filters.value
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script$2), {
|
||||
field: "name",
|
||||
header: _ctx.$t("extensionName"),
|
||||
sortable: ""
|
||||
}, null, 8, ["header"]),
|
||||
createVNode(unref(script$2), { pt: {
|
||||
bodyCell: "flex items-center justify-end"
|
||||
} }, {
|
||||
body: withCtx((slotProps) => [
|
||||
createVNode(unref(script$3), {
|
||||
modelValue: editingEnabledExtensions.value[slotProps.data.name],
|
||||
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
|
||||
onChange: updateExtensionStatus
|
||||
}, null, 8, ["modelValue", "onUpdate:modelValue"])
|
||||
]),
|
||||
_: 1
|
||||
})
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["value", "filters"])
|
||||
]),
|
||||
_: 1
|
||||
});
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ExtensionPanel-DsD42OtO.js.map
|
||||
1
web/assets/ExtensionPanel-DsD42OtO.js.map
generated
vendored
Normal file
1
web/assets/ExtensionPanel-DsD42OtO.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"ExtensionPanel-DsD42OtO.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n 
window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
1175
web/assets/GraphView-C4blCugc.js → web/assets/GraphView-BW5soyxY.js
generated
vendored
1175
web/assets/GraphView-C4blCugc.js → web/assets/GraphView-BW5soyxY.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-BW5soyxY.js.map
generated
vendored
Normal file
1
web/assets/GraphView-BW5soyxY.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
File diff suppressed because one or more lines are too long
90
web/assets/GraphView-Cf7ubG48.css → web/assets/GraphView-DtkYXy38.css
generated
vendored
90
web/assets/GraphView-Cf7ubG48.css → web/assets/GraphView-DtkYXy38.css
generated
vendored
@@ -45,7 +45,7 @@
|
||||
--sidebar-icon-size: 1rem;
|
||||
}
|
||||
|
||||
.side-tool-bar-container[data-v-37fd2fa4] {
|
||||
.side-tool-bar-container[data-v-e0812a25] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
@@ -58,28 +58,32 @@
|
||||
background-color: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
}
|
||||
.side-tool-bar-end[data-v-37fd2fa4] {
|
||||
.side-tool-bar-end[data-v-e0812a25] {
|
||||
align-self: flex-end;
|
||||
margin-top: auto;
|
||||
}
|
||||
|
||||
[data-v-b49f20b1] .p-splitter-gutter {
|
||||
[data-v-7c3279c1] .p-splitter-gutter {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.side-bar-panel[data-v-b49f20b1] {
|
||||
[data-v-7c3279c1] .p-splitter-gutter:hover,[data-v-7c3279c1] .p-splitter-gutter[data-p-gutter-resizing='true'] {
|
||||
transition: background-color 0.2s ease 300ms;
|
||||
background-color: var(--p-primary-color);
|
||||
}
|
||||
.side-bar-panel[data-v-7c3279c1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.bottom-panel[data-v-b49f20b1] {
|
||||
.bottom-panel[data-v-7c3279c1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.splitter-overlay[data-v-b49f20b1] {
|
||||
.splitter-overlay[data-v-7c3279c1] {
|
||||
pointer-events: none;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
}
|
||||
.splitter-overlay-root[data-v-b49f20b1] {
|
||||
.splitter-overlay-root[data-v-7c3279c1] {
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
left: 0px;
|
||||
@@ -102,32 +106,6 @@
|
||||
margin: -0.125rem 0.125rem;
|
||||
}
|
||||
|
||||
.comfy-vue-node-search-container[data-v-2d409367] {
|
||||
display: flex;
|
||||
width: 100%;
|
||||
min-width: 26rem;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.comfy-vue-node-search-container[data-v-2d409367] * {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.comfy-vue-node-preview-container[data-v-2d409367] {
|
||||
position: absolute;
|
||||
left: -350px;
|
||||
top: 50px;
|
||||
}
|
||||
.comfy-vue-node-search-box[data-v-2d409367] {
|
||||
z-index: 10;
|
||||
flex-grow: 1;
|
||||
}
|
||||
._filter-button[data-v-2d409367] {
|
||||
z-index: 10;
|
||||
}
|
||||
._dialog[data-v-2d409367] {
|
||||
min-width: 26rem;
|
||||
}
|
||||
|
||||
.invisible-dialog-root {
|
||||
width: 60%;
|
||||
min-width: 24rem;
|
||||
@@ -146,7 +124,7 @@
|
||||
align-items: flex-start !important;
|
||||
}
|
||||
|
||||
.node-tooltip[data-v-79ec8c53] {
|
||||
.node-tooltip[data-v-c2e0098f] {
|
||||
background: var(--comfy-input-bg);
|
||||
border-radius: 5px;
|
||||
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||
@@ -162,22 +140,28 @@
|
||||
z-index: 99999;
|
||||
}
|
||||
|
||||
.p-buttongroup-vertical[data-v-444d3768] {
|
||||
.p-buttongroup-vertical[data-v-94481f39] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-radius: var(--p-button-border-radius);
|
||||
overflow: hidden;
|
||||
border: 1px solid var(--p-panel-border-color);
|
||||
}
|
||||
.p-buttongroup-vertical .p-button[data-v-444d3768] {
|
||||
.p-buttongroup-vertical .p-button[data-v-94481f39] {
|
||||
margin: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
[data-v-84e785b8] .p-togglebutton::before {
|
||||
.comfy-menu-hamburger[data-v-2ddd26e8] {
|
||||
pointer-events: auto;
|
||||
position: fixed;
|
||||
z-index: 9999;
|
||||
}
|
||||
|
||||
[data-v-783f8efe] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton {
|
||||
[data-v-783f8efe] .p-togglebutton {
|
||||
position: relative;
|
||||
flex-shrink: 0;
|
||||
border-radius: 0px;
|
||||
@@ -185,14 +169,14 @@
|
||||
padding-left: 0.5rem;
|
||||
padding-right: 0.5rem
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||
[data-v-783f8efe] .p-togglebutton.p-togglebutton-checked {
|
||||
border-bottom-width: 2px;
|
||||
border-bottom-color: var(--p-button-text-primary-color)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||
[data-v-783f8efe] .p-togglebutton-checked .close-button,[data-v-783f8efe] .p-togglebutton:hover .close-button {
|
||||
visibility: visible
|
||||
}
|
||||
.status-indicator[data-v-84e785b8] {
|
||||
.status-indicator[data-v-783f8efe] {
|
||||
position: absolute;
|
||||
font-weight: 700;
|
||||
font-size: 1.5rem;
|
||||
@@ -200,10 +184,10 @@
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||
[data-v-783f8efe] .p-togglebutton:hover .status-indicator {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||
[data-v-783f8efe] .p-togglebutton .close-button {
|
||||
visibility: hidden
|
||||
}
|
||||
|
||||
@@ -226,35 +210,35 @@
|
||||
border-bottom-left-radius: 0;
|
||||
}
|
||||
|
||||
.comfyui-queue-button[data-v-2b80bf74] .p-splitbutton-dropdown {
|
||||
.comfyui-queue-button[data-v-95bc9be0] .p-splitbutton-dropdown {
|
||||
border-top-right-radius: 0;
|
||||
border-bottom-right-radius: 0;
|
||||
}
|
||||
|
||||
.actionbar[data-v-2e54db00] {
|
||||
.actionbar[data-v-542a7001] {
|
||||
pointer-events: all;
|
||||
position: fixed;
|
||||
z-index: 1000;
|
||||
}
|
||||
.actionbar.is-docked[data-v-2e54db00] {
|
||||
.actionbar.is-docked[data-v-542a7001] {
|
||||
position: static;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
padding: 0px;
|
||||
}
|
||||
.actionbar.is-dragging[data-v-2e54db00] {
|
||||
.actionbar.is-dragging[data-v-542a7001] {
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
[data-v-2e54db00] .p-panel-content {
|
||||
[data-v-542a7001] .p-panel-content {
|
||||
padding: 0.25rem;
|
||||
}
|
||||
[data-v-2e54db00] .p-panel-header {
|
||||
[data-v-542a7001] .p-panel-header {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfyui-menu[data-v-ad2c662b] {
|
||||
.comfyui-menu[data-v-d84a704d] {
|
||||
width: 100vw;
|
||||
background: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
@@ -266,13 +250,13 @@
|
||||
grid-column: 1/-1;
|
||||
max-height: 90vh;
|
||||
}
|
||||
.comfyui-menu.dropzone[data-v-ad2c662b] {
|
||||
.comfyui-menu.dropzone[data-v-d84a704d] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
.comfyui-menu.dropzone-active[data-v-ad2c662b] {
|
||||
.comfyui-menu.dropzone-active[data-v-d84a704d] {
|
||||
background: var(--p-highlight-background-focus);
|
||||
}
|
||||
.comfyui-logo[data-v-ad2c662b] {
|
||||
.comfyui-logo[data-v-d84a704d] {
|
||||
font-size: 1.2em;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
1074
web/assets/InstallView-C6UIhIu4.js
generated
vendored
Normal file
1074
web/assets/InstallView-C6UIhIu4.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/InstallView-C6UIhIu4.js.map
generated
vendored
Normal file
1
web/assets/InstallView-C6UIhIu4.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
4
web/assets/InstallView-CN3CA9Fk.css
generated
vendored
Normal file
4
web/assets/InstallView-CN3CA9Fk.css
generated
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
[data-v-53e62b05] .p-steppanel {
|
||||
background-color: transparent
|
||||
}
|
||||
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
@@ -1,8 +0,0 @@
|
||||
|
||||
[data-v-e5724e4d] .p-datatable-tbody > tr > td {
|
||||
padding: 1px;
|
||||
min-height: 2rem;
|
||||
}
|
||||
[data-v-e5724e4d] .p-datatable-row-selected .actions,[data-v-e5724e4d] .p-datatable-selectable-row:hover .actions {
|
||||
visibility: visible;
|
||||
}
|
||||
8
web/assets/KeybindingPanel-C-7KE-Kw.css
generated
vendored
Normal file
8
web/assets/KeybindingPanel-C-7KE-Kw.css
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
[data-v-9d7e362e] .p-datatable-tbody > tr > td {
|
||||
padding: 0.25rem;
|
||||
min-height: 2rem
|
||||
}
|
||||
[data-v-9d7e362e] .p-datatable-row-selected .actions,[data-v-9d7e362e] .p-datatable-selectable-row:hover .actions {
|
||||
visibility: visible
|
||||
}
|
||||
264
web/assets/KeybindingPanel-Dm_3sBT5.js
generated
vendored
264
web/assets/KeybindingPanel-Dm_3sBT5.js
generated
vendored
@@ -1,264 +0,0 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, bN as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, b9 as useToast, t as resolveDirective, bO as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ao as script$4, be as withModifiers, aH as script$6, v as withDirectives, P as pushScopeId, Q as popScopeId, bJ as KeyComboImpl, bP as KeybindingImpl, _ as _export_sfc } from "./index-BHayQCxv.js";
|
||||
import { s as script$1, a as script$3, b as script$5 } from "./index-CwRXxFdA.js";
|
||||
import "./index-C_wOqB0f.js";
|
||||
const _hoisted_1$1 = {
|
||||
key: 0,
|
||||
class: "px-2"
|
||||
};
|
||||
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
|
||||
__name: "KeyComboDisplay",
|
||||
props: {
|
||||
keyCombo: {},
|
||||
isModified: { type: Boolean, default: false }
|
||||
},
|
||||
setup(__props) {
|
||||
const props = __props;
|
||||
const keySequences = computed(() => props.keyCombo.getKeySequences());
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("span", null, [
|
||||
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
|
||||
return openBlock(), createElementBlock(Fragment, { key: index }, [
|
||||
createVNode(unref(script), {
|
||||
severity: _ctx.isModified ? "info" : "secondary"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createTextVNode(toDisplayString(sequence), 1)
|
||||
]),
|
||||
_: 2
|
||||
}, 1032, ["severity"]),
|
||||
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
|
||||
], 64);
|
||||
}), 128))
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-e5724e4d"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
const _hoisted_1 = { class: "keybinding-panel" };
|
||||
const _hoisted_2 = { class: "actions invisible" };
|
||||
const _hoisted_3 = { key: 1 };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "KeybindingPanel",
|
||||
setup(__props) {
|
||||
const filters = ref({
|
||||
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
|
||||
});
|
||||
const keybindingStore = useKeybindingStore();
|
||||
const commandStore = useCommandStore();
|
||||
const commandsData = computed(() => {
|
||||
return Object.values(commandStore.commands).map((command) => ({
|
||||
id: command.id,
|
||||
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
|
||||
}));
|
||||
});
|
||||
const selectedCommandData = ref(null);
|
||||
const editDialogVisible = ref(false);
|
||||
const newBindingKeyCombo = ref(null);
|
||||
const currentEditingCommand = ref(null);
|
||||
const keybindingInput = ref(null);
|
||||
const existingKeybindingOnCombo = computed(() => {
|
||||
if (!currentEditingCommand.value) {
|
||||
return null;
|
||||
}
|
||||
if (currentEditingCommand.value.keybinding?.combo?.equals(
|
||||
newBindingKeyCombo.value
|
||||
)) {
|
||||
return null;
|
||||
}
|
||||
if (!newBindingKeyCombo.value) {
|
||||
return null;
|
||||
}
|
||||
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
|
||||
});
|
||||
function editKeybinding(commandData) {
|
||||
currentEditingCommand.value = commandData;
|
||||
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
|
||||
editDialogVisible.value = true;
|
||||
}
|
||||
__name(editKeybinding, "editKeybinding");
|
||||
watchEffect(() => {
|
||||
if (editDialogVisible.value) {
|
||||
setTimeout(() => {
|
||||
keybindingInput.value?.$el?.focus();
|
||||
}, 300);
|
||||
}
|
||||
});
|
||||
function removeKeybinding(commandData) {
|
||||
if (commandData.keybinding) {
|
||||
keybindingStore.unsetKeybinding(commandData.keybinding);
|
||||
keybindingStore.persistUserKeybindings();
|
||||
}
|
||||
}
|
||||
__name(removeKeybinding, "removeKeybinding");
|
||||
function captureKeybinding(event) {
|
||||
const keyCombo = KeyComboImpl.fromEvent(event);
|
||||
newBindingKeyCombo.value = keyCombo;
|
||||
}
|
||||
__name(captureKeybinding, "captureKeybinding");
|
||||
function cancelEdit() {
|
||||
editDialogVisible.value = false;
|
||||
currentEditingCommand.value = null;
|
||||
newBindingKeyCombo.value = null;
|
||||
}
|
||||
__name(cancelEdit, "cancelEdit");
|
||||
function saveKeybinding() {
|
||||
if (currentEditingCommand.value && newBindingKeyCombo.value) {
|
||||
const updated = keybindingStore.updateKeybindingOnCommand(
|
||||
new KeybindingImpl({
|
||||
commandId: currentEditingCommand.value.id,
|
||||
combo: newBindingKeyCombo.value
|
||||
})
|
||||
);
|
||||
if (updated) {
|
||||
keybindingStore.persistUserKeybindings();
|
||||
}
|
||||
}
|
||||
cancelEdit();
|
||||
}
|
||||
__name(saveKeybinding, "saveKeybinding");
|
||||
const toast = useToast();
|
||||
async function resetKeybindings() {
|
||||
keybindingStore.resetKeybindings();
|
||||
await keybindingStore.persistUserKeybindings();
|
||||
toast.add({
|
||||
severity: "info",
|
||||
summary: "Info",
|
||||
detail: "Keybindings reset",
|
||||
life: 3e3
|
||||
});
|
||||
}
|
||||
__name(resetKeybindings, "resetKeybindings");
|
||||
return (_ctx, _cache) => {
|
||||
const _directive_tooltip = resolveDirective("tooltip");
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createVNode(unref(script$3), {
|
||||
value: commandsData.value,
|
||||
selection: selectedCommandData.value,
|
||||
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
|
||||
"global-filter-fields": ["id"],
|
||||
filters: filters.value,
|
||||
selectionMode: "single",
|
||||
stripedRows: "",
|
||||
pt: {
|
||||
header: "px-0"
|
||||
}
|
||||
}, {
|
||||
header: withCtx(() => [
|
||||
createVNode(SearchBox, {
|
||||
modelValue: filters.value["global"].value,
|
||||
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
|
||||
placeholder: _ctx.$t("searchKeybindings") + "..."
|
||||
}, null, 8, ["modelValue", "placeholder"])
|
||||
]),
|
||||
default: withCtx(() => [
|
||||
createVNode(unref(script$1), {
|
||||
field: "actions",
|
||||
header: ""
|
||||
}, {
|
||||
body: withCtx((slotProps) => [
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
createVNode(unref(script$2), {
|
||||
icon: "pi pi-pencil",
|
||||
class: "p-button-text",
|
||||
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
|
||||
}, null, 8, ["onClick"]),
|
||||
createVNode(unref(script$2), {
|
||||
icon: "pi pi-trash",
|
||||
class: "p-button-text p-button-danger",
|
||||
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
|
||||
disabled: !slotProps.data.keybinding
|
||||
}, null, 8, ["onClick", "disabled"])
|
||||
])
|
||||
]),
|
||||
_: 1
|
||||
}),
|
||||
createVNode(unref(script$1), {
|
||||
field: "id",
|
||||
header: "Command ID",
|
||||
sortable: ""
|
||||
}),
|
||||
createVNode(unref(script$1), {
|
||||
field: "keybinding",
|
||||
header: "Keybinding"
|
||||
}, {
|
||||
body: withCtx((slotProps) => [
|
||||
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
|
||||
key: 0,
|
||||
keyCombo: slotProps.data.keybinding.combo,
|
||||
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
|
||||
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
|
||||
]),
|
||||
_: 1
|
||||
})
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["value", "selection", "filters"]),
|
||||
createVNode(unref(script$6), {
|
||||
class: "min-w-96",
|
||||
visible: editDialogVisible.value,
|
||||
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
|
||||
modal: "",
|
||||
header: currentEditingCommand.value?.id,
|
||||
onHide: cancelEdit
|
||||
}, {
|
||||
footer: withCtx(() => [
|
||||
createVNode(unref(script$2), {
|
||||
label: "Save",
|
||||
icon: "pi pi-check",
|
||||
onClick: saveKeybinding,
|
||||
disabled: !!existingKeybindingOnCombo.value,
|
||||
autofocus: ""
|
||||
}, null, 8, ["disabled"])
|
||||
]),
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("div", null, [
|
||||
createVNode(unref(script$4), {
|
||||
class: "mb-2 text-center",
|
||||
ref_key: "keybindingInput",
|
||||
ref: keybindingInput,
|
||||
modelValue: newBindingKeyCombo.value?.toString() ?? "",
|
||||
placeholder: "Press keys for new binding",
|
||||
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
|
||||
autocomplete: "off",
|
||||
fluid: "",
|
||||
invalid: !!existingKeybindingOnCombo.value
|
||||
}, null, 8, ["modelValue", "invalid"]),
|
||||
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
|
||||
key: 0,
|
||||
severity: "error"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createTextVNode(" Keybinding already exists on "),
|
||||
createVNode(unref(script), {
|
||||
severity: "secondary",
|
||||
value: existingKeybindingOnCombo.value.commandId
|
||||
}, null, 8, ["value"])
|
||||
]),
|
||||
_: 1
|
||||
})) : createCommentVNode("", true)
|
||||
])
|
||||
]),
|
||||
_: 1
|
||||
}, 8, ["visible", "header"]),
|
||||
withDirectives(createVNode(unref(script$2), {
|
||||
class: "mt-4",
|
||||
label: _ctx.$t("reset"),
|
||||
icon: "pi pi-trash",
|
||||
severity: "danger",
|
||||
fluid: "",
|
||||
text: "",
|
||||
onClick: resetKeybindings
|
||||
}, null, 8, ["label"]), [
|
||||
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-e5724e4d"]]);
|
||||
export {
|
||||
KeybindingPanel as default
|
||||
};
|
||||
//# sourceMappingURL=KeybindingPanel-Dm_3sBT5.js.map
|
||||
1
web/assets/KeybindingPanel-Dm_3sBT5.js.map
generated
vendored
1
web/assets/KeybindingPanel-Dm_3sBT5.js.map
generated
vendored
File diff suppressed because one or more lines are too long
279
web/assets/KeybindingPanel-lcJrxHwZ.js
generated
vendored
Normal file
279
web/assets/KeybindingPanel-lcJrxHwZ.js
generated
vendored
Normal file
@@ -0,0 +1,279 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, c6 as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, bg as useToast, t as resolveDirective, x as createBlock, c7 as SearchBox, A as createBaseVNode, D as script$2, ao as script$4, bk as withModifiers, bT as script$5, aH as script$6, v as withDirectives, c8 as _sfc_main$2, P as pushScopeId, Q as popScopeId, c1 as KeyComboImpl, c9 as KeybindingImpl, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { s as script$1, a as script$3 } from "./index-DK6Kev7f.js";
import "./index-D4DWQPPQ.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
};
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "KeyComboDisplay",
props: {
keyCombo: {},
isModified: { type: Boolean, default: false }
},
setup(__props) {
const props = __props;
const keySequences = computed(() => props.keyCombo.getKeySequences());
return (_ctx, _cache) => {
return openBlock(), createElementBlock("span", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
return openBlock(), createElementBlock(Fragment, { key: index }, [
createVNode(unref(script), {
severity: _ctx.isModified ? "info" : "secondary"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(sequence), 1)
]),
_: 2
}, 1032, ["severity"]),
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
], 64);
}), 128))
]);
};
}
});
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-9d7e362e"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "actions invisible flex flex-row" };
const _hoisted_2 = ["title"];
const _hoisted_3 = { key: 1 };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "KeybindingPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const keybindingStore = useKeybindingStore();
const commandStore = useCommandStore();
const commandsData = computed(() => {
return Object.values(commandStore.commands).map((command) => ({
id: command.id,
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
}));
});
const selectedCommandData = ref(null);
const editDialogVisible = ref(false);
const newBindingKeyCombo = ref(null);
const currentEditingCommand = ref(null);
const keybindingInput = ref(null);
const existingKeybindingOnCombo = computed(() => {
if (!currentEditingCommand.value) {
return null;
}
if (currentEditingCommand.value.keybinding?.combo?.equals(
newBindingKeyCombo.value
)) {
return null;
}
if (!newBindingKeyCombo.value) {
return null;
}
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
});
function editKeybinding(commandData) {
currentEditingCommand.value = commandData;
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
editDialogVisible.value = true;
}
__name(editKeybinding, "editKeybinding");
watchEffect(() => {
if (editDialogVisible.value) {
setTimeout(() => {
keybindingInput.value?.$el?.focus();
}, 300);
}
});
function removeKeybinding(commandData) {
if (commandData.keybinding) {
keybindingStore.unsetKeybinding(commandData.keybinding);
keybindingStore.persistUserKeybindings();
}
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
__name(captureKeybinding, "captureKeybinding");
function cancelEdit() {
editDialogVisible.value = false;
currentEditingCommand.value = null;
newBindingKeyCombo.value = null;
}
__name(cancelEdit, "cancelEdit");
function saveKeybinding() {
if (currentEditingCommand.value && newBindingKeyCombo.value) {
const updated = keybindingStore.updateKeybindingOnCommand(
new KeybindingImpl({
commandId: currentEditingCommand.value.id,
combo: newBindingKeyCombo.value
})
);
if (updated) {
keybindingStore.persistUserKeybindings();
}
}
cancelEdit();
}
__name(saveKeybinding, "saveKeybinding");
const toast = useToast();
async function resetKeybindings() {
keybindingStore.resetKeybindings();
await keybindingStore.persistUserKeybindings();
toast.add({
severity: "info",
summary: "Info",
detail: "Keybindings reset",
life: 3e3
});
}
__name(resetKeybindings, "resetKeybindings");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createBlock(_sfc_main$2, {
value: "Keybinding",
class: "keybinding-panel"
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchKeybindings") + "..."
}, null, 8, ["modelValue", "placeholder"])
]),
default: withCtx(() => [
createVNode(unref(script$3), {
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
pt: {
header: "px-0"
}
}, {
default: withCtx(() => [
createVNode(unref(script$1), {
field: "actions",
header: ""
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$2), {
icon: "pi pi-pencil",
class: "p-button-text",
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
}, null, 8, ["onClick"]),
createVNode(unref(script$2), {
icon: "pi pi-trash",
class: "p-button-text p-button-danger",
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
disabled: !slotProps.data.keybinding
}, null, 8, ["onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$1), {
field: "id",
header: "Command ID",
sortable: "",
class: "max-w-64 2xl:max-w-full"
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", {
class: "overflow-hidden text-ellipsis whitespace-nowrap",
title: slotProps.data.id
}, toDisplayString(slotProps.data.id), 9, _hoisted_2)
]),
_: 1
}),
createVNode(unref(script$1), {
field: "keybinding",
header: "Keybinding"
}, {
body: withCtx((slotProps) => [
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
key: 0,
keyCombo: slotProps.data.keybinding.combo,
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
]),
_: 1
})
]),
_: 1
}, 8, ["value", "selection", "filters"]),
createVNode(unref(script$6), {
class: "min-w-96",
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
onHide: cancelEdit
}, {
footer: withCtx(() => [
createVNode(unref(script$2), {
label: "Save",
icon: "pi pi-check",
onClick: saveKeybinding,
disabled: !!existingKeybindingOnCombo.value,
autofocus: ""
}, null, 8, ["disabled"])
]),
default: withCtx(() => [
createBaseVNode("div", null, [
createVNode(unref(script$4), {
class: "mb-2 text-center",
ref_key: "keybindingInput",
ref: keybindingInput,
modelValue: newBindingKeyCombo.value?.toString() ?? "",
placeholder: "Press keys for new binding",
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
autocomplete: "off",
fluid: "",
invalid: !!existingKeybindingOnCombo.value
}, null, 8, ["modelValue", "invalid"]),
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(" Keybinding already exists on "),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
}, null, 8, ["value"])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
_: 1
}, 8, ["visible", "header"]),
withDirectives(createVNode(unref(script$2), {
class: "mt-4",
label: _ctx.$t("reset"),
icon: "pi pi-trash",
severity: "danger",
fluid: "",
text: "",
onClick: resetKeybindings
}, null, 8, ["label"]), [
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
])
]),
_: 1
});
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-9d7e362e"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-lcJrxHwZ.js.map
1
web/assets/KeybindingPanel-lcJrxHwZ.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
150
web/assets/ServerConfigPanel-x68ubY-c.js
generated
vendored
Normal file
@@ -0,0 +1,150 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { A as createBaseVNode, g as openBlock, h as createElementBlock, aU as markRaw, d as defineComponent, u as useSettingStore, bw as storeToRefs, w as watch, cy as useCopyToClipboard, x as createBlock, y as withCtx, z as unref, bT as script, a6 as toDisplayString, O as renderList, N as Fragment, i as createVNode, D as script$1, j as createCommentVNode, bI as script$2, cz as formatCamelCase, cA as FormItem, c8 as _sfc_main$1, bN as electronAPI } from "./index-CoOvI8ZH.js";
import { u as useServerConfigStore } from "./serverConfigStore-cctR8PGG.js";
const _hoisted_1$1 = {
viewBox: "0 0 24 24",
width: "1.2em",
height: "1.2em"
};
const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
fill: "none",
stroke: "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round",
"stroke-width": "2",
d: "m4 17l6-6l-6-6m8 14h8"
}, null, -1);
const _hoisted_3$1 = [
_hoisted_2$1
];
function render(_ctx, _cache) {
return openBlock(), createElementBlock("svg", _hoisted_1$1, [..._hoisted_3$1]);
}
__name(render, "render");
const __unplugin_components_0 = markRaw({ name: "lucide-terminal", render });
const _hoisted_1 = { class: "flex flex-col gap-2" };
const _hoisted_2 = { class: "flex justify-end gap-2" };
const _hoisted_3 = { class: "flex items-center justify-between" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerConfigPanel",
setup(__props) {
const settingStore = useSettingStore();
const serverConfigStore = useServerConfigStore();
const {
serverConfigsByCategory,
serverConfigValues,
launchArgs,
commandLineArgs,
modifiedConfigs
} = storeToRefs(serverConfigStore);
const revertChanges = /* @__PURE__ */ __name(() => {
serverConfigStore.revertChanges();
}, "revertChanges");
const restartApp = /* @__PURE__ */ __name(() => {
electronAPI().restartApp();
}, "restartApp");
watch(launchArgs, (newVal) => {
settingStore.set("Comfy.Server.LaunchArgs", newVal);
});
watch(serverConfigValues, (newVal) => {
settingStore.set("Comfy.Server.ServerConfigValues", newVal);
});
const { copyToClipboard } = useCopyToClipboard();
const copyCommandLineArgs = /* @__PURE__ */ __name(async () => {
await copyToClipboard(commandLineArgs.value);
}, "copyCommandLineArgs");
return (_ctx, _cache) => {
const _component_i_lucide58terminal = __unplugin_components_0;
return openBlock(), createBlock(_sfc_main$1, {
value: "Server-Config",
class: "server-config-panel"
}, {
header: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
unref(modifiedConfigs).length > 0 ? (openBlock(), createBlock(unref(script), {
key: 0,
severity: "info",
"pt:text": "w-full"
}, {
default: withCtx(() => [
createBaseVNode("p", null, toDisplayString(_ctx.$t("serverConfig.modifiedConfigs")), 1),
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(unref(modifiedConfigs), (config) => {
return openBlock(), createElementBlock("li", {
key: config.id
}, toDisplayString(config.name) + ": " + toDisplayString(config.initialValue) + " → " + toDisplayString(config.value), 1);
}), 128))
]),
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.revertChanges"),
onClick: revertChanges,
outlined: ""
}, null, 8, ["label"]),
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.restart"),
onClick: restartApp,
outlined: "",
severity: "danger"
}, null, 8, ["label"])
])
]),
_: 1
})) : createCommentVNode("", true),
unref(commandLineArgs) ? (openBlock(), createBlock(unref(script), {
key: 1,
severity: "secondary",
"pt:text": "w-full"
}, {
icon: withCtx(() => [
createVNode(_component_i_lucide58terminal, { class: "text-xl font-bold" })
]),
default: withCtx(() => [
createBaseVNode("div", _hoisted_3, [
createBaseVNode("p", null, toDisplayString(unref(commandLineArgs)), 1),
createVNode(unref(script$1), {
icon: "pi pi-clipboard",
onClick: copyCommandLineArgs,
severity: "secondary",
text: ""
})
])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
default: withCtx(() => [
(openBlock(true), createElementBlock(Fragment, null, renderList(Object.entries(unref(serverConfigsByCategory)), ([label, items], i) => {
return openBlock(), createElementBlock("div", { key: label }, [
i > 0 ? (openBlock(), createBlock(unref(script$2), { key: 0 })) : createCommentVNode("", true),
createBaseVNode("h3", null, toDisplayString(unref(formatCamelCase)(label)), 1),
(openBlock(true), createElementBlock(Fragment, null, renderList(items, (item) => {
return openBlock(), createElementBlock("div", {
key: item.name,
class: "flex items-center mb-4"
}, [
createVNode(FormItem, {
item,
formValue: item.value,
"onUpdate:formValue": /* @__PURE__ */ __name(($event) => item.value = $event, "onUpdate:formValue"),
id: item.id,
labelClass: {
"text-highlight": item.initialValue !== item.value
}
}, null, 8, ["item", "formValue", "onUpdate:formValue", "id", "labelClass"])
]);
}), 128))
]);
}), 128))
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ServerConfigPanel-x68ubY-c.js.map
1
web/assets/ServerConfigPanel-x68ubY-c.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
{"version":3,"file":"ServerConfigPanel-x68ubY-c.js","sources":["../../src/components/dialog/content/setting/ServerConfigPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Server-Config\" class=\"server-config-panel\">\n <template #header>\n <div class=\"flex flex-col gap-2\">\n <Message\n v-if=\"modifiedConfigs.length > 0\"\n severity=\"info\"\n pt:text=\"w-full\"\n >\n <p>\n {{ $t('serverConfig.modifiedConfigs') }}\n </p>\n <ul>\n <li v-for=\"config in modifiedConfigs\" :key=\"config.id\">\n {{ config.name }}: {{ config.initialValue }} → {{ config.value }}\n </li>\n </ul>\n <div class=\"flex justify-end gap-2\">\n <Button\n :label=\"$t('serverConfig.revertChanges')\"\n @click=\"revertChanges\"\n outlined\n />\n <Button\n :label=\"$t('serverConfig.restart')\"\n @click=\"restartApp\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n <Message v-if=\"commandLineArgs\" severity=\"secondary\" pt:text=\"w-full\">\n <template #icon>\n <i-lucide:terminal class=\"text-xl font-bold\" />\n </template>\n <div class=\"flex items-center justify-between\">\n <p>{{ commandLineArgs }}</p>\n <Button\n icon=\"pi pi-clipboard\"\n @click=\"copyCommandLineArgs\"\n severity=\"secondary\"\n text\n />\n </div>\n </Message>\n </div>\n </template>\n <div\n v-for=\"([label, items], i) in Object.entries(serverConfigsByCategory)\"\n :key=\"label\"\n >\n <Divider v-if=\"i > 0\" />\n <h3>{{ formatCamelCase(label) }}</h3>\n <div\n v-for=\"item in items\"\n :key=\"item.name\"\n class=\"flex items-center mb-4\"\n >\n <FormItem\n :item=\"item\"\n v-model:formValue=\"item.value\"\n :id=\"item.id\"\n :labelClass=\"{\n 'text-highlight': item.initialValue !== item.value\n }\"\n />\n </div>\n </div>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport Divider from 'primevue/divider'\nimport FormItem from '@/components/common/FormItem.vue'\nimport PanelTemplate from './PanelTemplate.vue'\nimport { formatCamelCase } from '@/utils/formatUtil'\nimport { useServerConfigStore } from '@/stores/serverConfigStore'\nimport { storeToRefs } from 'pinia'\nimport { electronAPI } from '@/utils/envUtil'\nimport { useSettingStore } from '@/stores/settingStore'\nimport { watch } from 'vue'\nimport { useCopyToClipboard } from '@/hooks/clipboardHooks'\n\nconst settingStore = useSettingStore()\nconst serverConfigStore = useServerConfigStore()\nconst {\n serverConfigsByCategory,\n serverConfigValues,\n launchArgs,\n commandLineArgs,\n modifiedConfigs\n} = storeToRefs(serverConfigStore)\n\nconst revertChanges = () => {\n serverConfigStore.revertChanges()\n}\n\nconst restartApp = () => {\n electronAPI().restartApp()\n}\n\nwatch(launchArgs, (newVal) => {\n settingStore.set('Comfy.Server.LaunchArgs', newVal)\n})\n\nwatch(serverConfigValues, (newVal) => {\n settingStore.set('Comfy.Server.ServerConfigValues', newVal)\n})\n\nconst { copyToClipboard } = useCopyToClipboard()\nconst copyCommandLineArgs = async () => {\n await copyToClipboard(commandLineArgs.value)\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAqFA,UAAM,eAAe;AACrB,UAAM,oBAAoB;AACpB,UAAA;AAAA,MACJ;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,IAAA,IACE,YAAY,iBAAiB;AAEjC,UAAM,gBAAgB,6BAAM;AAC1B,wBAAkB,cAAc;AAAA,IAAA,GADZ;AAItB,UAAM,aAAa,6BAAM;AACvB,kBAAA,EAAc;IAAW,GADR;AAIb,UAAA,YAAY,CAAC,WAAW;AACf,mBAAA,IAAI,2BAA2B,MAAM;AAAA,IAAA,CACnD;AAEK,UAAA,oBAAoB,CAAC,WAAW;AACvB,mBAAA,IAAI,mCAAmC,MAAM;AAAA,IAAA,CAC3D;AAEK,UAAA,EAAE,oBAAoB;AAC5B,UAAM,sBAAsB,mCAAY;AAChC,YAAA,gBAAgB,gBAAgB,KAAK;AAAA,IAAA,GADjB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
79
web/assets/ServerStartView-CqRVtr1h.js
generated
vendored
Normal file
@@ -0,0 +1,79 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, aD as useI18n, r as ref, o as onMounted, g as openBlock, h as createElementBlock, A as createBaseVNode, aw as createTextVNode, a6 as toDisplayString, z as unref, j as createCommentVNode, i as createVNode, D as script, bM as BaseTerminal, P as pushScopeId, Q as popScopeId, bN as electronAPI, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { P as ProgressStatus } from "./index-BppSBmxJ.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-f5429be7"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "text-2xl font-bold" };
const _hoisted_3 = { key: 0 };
const _hoisted_4 = {
key: 0,
class: "flex items-center my-4 gap-2"
};
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const { t } = useI18n();
const status = ref(ProgressStatus.INITIAL_STATE);
const electronVersion = ref("");
let xterm;
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
xterm?.clear();
}, "updateProgress");
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
xterm = terminal;
useAutoSize(root, true, true);
electron.onLogMessage((message) => {
terminal.write(message);
});
terminal.options.cursorBlink = false;
terminal.options.disableStdin = true;
terminal.options.cursorInactiveStyle = "block";
}, "terminalCreated");
const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
const reportIssue = /* @__PURE__ */ __name(() => {
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
}, "reportIssue");
const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
onMounted(async () => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electronVersion.value = await electron.getElectronVersion();
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createBaseVNode("h2", _hoisted_2, [
createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
]),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
createVNode(unref(script), {
icon: "pi pi-flag",
severity: "secondary",
label: unref(t)("serverStart.reportIssue"),
onClick: reportIssue
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-file",
severity: "secondary",
label: unref(t)("serverStart.openLogs"),
onClick: openLogs
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-refresh",
label: unref(t)("serverStart.reinstall"),
onClick: reinstall
}, null, 8, ["label"])
])) : createCommentVNode("", true),
createVNode(BaseTerminal, { onCreated: terminalCreated })
]);
};
}
});
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-f5429be7"]]);
export {
ServerStartView as default
};
//# sourceMappingURL=ServerStartView-CqRVtr1h.js.map
1
web/assets/ServerStartView-CqRVtr1h.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
{"version":3,"file":"ServerStartView-CqRVtr1h.js","sources":["../../src/views/ServerStartView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">\n {{ t(`serverStart.process.${status}`) }}\n <span v-if=\"status === ProgressStatus.ERROR\">\n v{{ electronVersion }}\n </span>\n </h2>\n <div\n v-if=\"status === ProgressStatus.ERROR\"\n class=\"flex items-center my-4 gap-2\"\n >\n <Button\n icon=\"pi pi-flag\"\n severity=\"secondary\"\n :label=\"t('serverStart.reportIssue')\"\n @click=\"reportIssue\"\n />\n <Button\n icon=\"pi pi-file\"\n severity=\"secondary\"\n :label=\"t('serverStart.openLogs')\"\n @click=\"openLogs\"\n />\n <Button\n icon=\"pi pi-refresh\"\n :label=\"t('serverStart.reinstall')\"\n @click=\"reinstall\"\n />\n </div>\n <BaseTerminal @created=\"terminalCreated\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { ref, onMounted, Ref } from 'vue'\nimport BaseTerminal from '@/components/bottomPanel/tabs/terminal/BaseTerminal.vue'\nimport { ProgressStatus } from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\nimport type { useTerminal } from '@/hooks/bottomPanelTabs/useTerminal'\nimport { Terminal } from '@xterm/xterm'\nimport { useI18n } from 'vue-i18n'\n\nconst electron = electronAPI()\nconst { t } = useI18n()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst electronVersion = ref<string>('')\nlet xterm: Terminal | undefined\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n xterm?.clear()\n}\n\nconst terminalCreated = (\n { terminal, useAutoSize }: ReturnType<typeof useTerminal>,\n root: Ref<HTMLElement>\n) => {\n xterm = terminal\n\n useAutoSize(root, true, true)\n electron.onLogMessage((message: string) => {\n terminal.write(message)\n })\n\n terminal.options.cursorBlink = false\n terminal.options.disableStdin = true\n terminal.options.cursorInactiveStyle = 'block'\n}\n\nconst reinstall = () => electron.reinstall()\nconst reportIssue = () => {\n window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\nconst openLogs = () => electron.openLogsFolder()\n\nonMounted(async () => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electronVersion.value = await electron.getElectronVersion()\n})\n</script>\n\n<style scoped>\n:deep(.xterm-helper-textarea) {\n /* Hide this as it moves all over when uv is running */\n display: none;\n}\n</style>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AA8CA,UAAM,WAAW;AACX,UAAA,EAAE,MAAM;AAER,UAAA,SAAS,IAAoB,eAAe,aAAa;AACzD,UAAA,kBAAkB,IAAY,EAAE;AAClC,QAAA;AAEJ,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AACf,aAAO,MAAM;AAAA,IAAA,GAFQ;AAKvB,UAAM,kBAAkB,wBACtB,EAAE,UAAU,YAAA,GACZ,SACG;AACK,cAAA;AAEI,kBAAA,MAAM,MAAM,IAAI;AACnB,eAAA,aAAa,CAAC,YAAoB;AACzC,iBAAS,MAAM,OAAO;AAAA,MAAA,CACvB;AAED,eAAS,QAAQ,cAAc;AAC/B,eAAS,QAAQ,eAAe;AAChC,eAAS,QAAQ,sBAAsB;AAAA,IAAA,GAbjB;AAgBlB,UAAA,YAAY,6BAAM,SAAS,aAAf;AAClB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAGd,UAAA,WAAW,6BAAM,SAAS,kBAAf;AAEjB,cAAU,YAAY;AACpB,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AACxB,sBAAA,QAAQ,MAAM,SAAS,mBAAmB;AAAA,IAAA,CAC3D;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
5
web/assets/ServerStartView-Djq8v91B.css
generated
vendored
Normal file
@@ -0,0 +1,5 @@

[data-v-f5429be7] .xterm-helper-textarea {
  /* Hide this as it moves all over when uv is running */
  display: none;
}
Some files were not shown because too many files have changed in this diff