Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 |
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "121"
|
default: "124"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
|
|||||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
torch_version: ["nightly"]
|
torch_version: ["nightly"]
|
||||||
include:
|
include:
|
||||||
- os: windows
|
- os: windows
|
||||||
runner_label: [self-hosted, win]
|
runner_label: [self-hosted, Windows]
|
||||||
flags: ""
|
flags: ""
|
||||||
runs-on: ${{ matrix.runner_label }}
|
runs-on: ${{ matrix.runner_label }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
30
.github/workflows/test-unit.yml
vendored
Normal file
30
.github/workflows/test-unit.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: Unit Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
continue-on-error: true
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Run Unit Tests
|
||||||
|
run: |
|
||||||
|
pip install -r tests-unit/requirements.txt
|
||||||
|
python -m pytest tests-unit
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
|||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv/
|
||||||
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
|||||||
@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| Alt + `+` | Canvas Zoom in |
|
| Alt + `+` | Canvas Zoom in |
|
||||||
| Alt + `-` | Canvas Zoom out |
|
| Alt + `-` | Canvas Zoom out |
|
||||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||||
|
| P | Pin/Unpin selected nodes |
|
||||||
|
| Ctrl + G | Group selected nodes |
|
||||||
| Q | Toggle visibility of the queue |
|
| Q | Toggle visibility of the queue |
|
||||||
| H | Toggle visibility of history |
|
| H | Toggle visibility of history |
|
||||||
| R | Refresh graph |
|
| R | Refresh graph |
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from folder_paths import models_dir, user_directory, output_directory
|
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||||
from api_server.services.file_service import FileService
|
from api_server.services.file_service import FileService
|
||||||
import app.logger
|
import app.logger
|
||||||
|
|
||||||
@@ -36,6 +36,13 @@ class InternalRoutes:
|
|||||||
async def get_logs(request):
|
async def get_logs(request):
|
||||||
return web.json_response(app.logger.get_logs())
|
return web.json_response(app.logger.get_logs())
|
||||||
|
|
||||||
|
@self.routes.get('/folder_paths')
|
||||||
|
async def get_folder_paths(request):
|
||||||
|
response = {}
|
||||||
|
for key in folder_names_and_paths:
|
||||||
|
response[key] = folder_names_and_paths[key][0]
|
||||||
|
return web.json_response(response)
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
self._app = web.Application()
|
self._app = web.Application()
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ def get_logs():
|
|||||||
return "\n".join([formatter.format(x) for x in logs])
|
return "\n".join([formatter.format(x) for x in logs])
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(verbose: bool = False, capacity: int = 300):
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||||
global logs
|
global logs
|
||||||
if logs:
|
if logs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Setup default global logger
|
# Setup default global logger
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
|||||||
@@ -5,17 +5,17 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from folder_paths import user_directory
|
import folder_paths
|
||||||
from .app_settings import AppSettings
|
from .app_settings import AppSettings
|
||||||
|
|
||||||
default_user = "default"
|
default_user = "default"
|
||||||
users_file = os.path.join(user_directory, "users.json")
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
class UserManager():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
@@ -25,14 +25,17 @@ class UserManager():
|
|||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(users_file):
|
if os.path.isfile(self.get_users_file()):
|
||||||
with open(users_file) as f:
|
with open(self.get_users_file()) as f:
|
||||||
self.users = json.load(f)
|
self.users = json.load(f)
|
||||||
else:
|
else:
|
||||||
self.users = {}
|
self.users = {}
|
||||||
else:
|
else:
|
||||||
self.users = {"default": "default"}
|
self.users = {"default": "default"}
|
||||||
|
|
||||||
|
def get_users_file(self):
|
||||||
|
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||||
|
|
||||||
def get_request_user_id(self, request):
|
def get_request_user_id(self, request):
|
||||||
user = "default"
|
user = "default"
|
||||||
if args.multi_user and "comfy-user" in request.headers:
|
if args.multi_user and "comfy-user" in request.headers:
|
||||||
@@ -44,7 +47,7 @@ class UserManager():
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
if type == "userdata":
|
if type == "userdata":
|
||||||
root_dir = user_directory
|
root_dir = user_directory
|
||||||
@@ -59,6 +62,10 @@ class UserManager():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
|
# Check if filename is url encoded
|
||||||
|
if "%" in file:
|
||||||
|
file = parse.unquote(file)
|
||||||
|
|
||||||
# prevent leaving /{type}/{user}
|
# prevent leaving /{type}/{user}
|
||||||
path = os.path.abspath(os.path.join(user_root, file))
|
path = os.path.abspath(os.path.join(user_root, file))
|
||||||
if os.path.commonpath((user_root, path)) != user_root:
|
if os.path.commonpath((user_root, path)) != user_root:
|
||||||
@@ -80,8 +87,7 @@ class UserManager():
|
|||||||
|
|
||||||
self.users[user_id] = name
|
self.users[user_id] = name
|
||||||
|
|
||||||
global users_file
|
with open(self.get_users_file(), "w") as f:
|
||||||
with open(users_file, "w") as f:
|
|
||||||
json.dump(self.users, f)
|
json.dump(self.users, f)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
@@ -112,25 +118,69 @@ class UserManager():
|
|||||||
|
|
||||||
@routes.get("/userdata")
|
@routes.get("/userdata")
|
||||||
async def listuserdata(request):
|
async def listuserdata(request):
|
||||||
|
"""
|
||||||
|
List user data files in a specified directory.
|
||||||
|
|
||||||
|
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||||
|
full file information, and path splitting.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- dir (required): The directory to list files from.
|
||||||
|
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||||
|
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||||
|
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- 400: If 'dir' parameter is missing.
|
||||||
|
- 403: If the requested path is not allowed.
|
||||||
|
- 404: If the requested directory does not exist.
|
||||||
|
- 200: JSON response with the list of files or file information.
|
||||||
|
|
||||||
|
The response format depends on the query parameters:
|
||||||
|
- Default: List of relative file paths.
|
||||||
|
- full_info=true: List of dictionaries with file details.
|
||||||
|
- split=true (and full_info=false): List of lists, each containing path components.
|
||||||
|
"""
|
||||||
directory = request.rel_url.query.get('dir', '')
|
directory = request.rel_url.query.get('dir', '')
|
||||||
if not directory:
|
if not directory:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400, text="Directory not provided")
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, directory)
|
path = self.get_request_user_filepath(request, directory)
|
||||||
if not path:
|
if not path:
|
||||||
return web.Response(status=403)
|
return web.Response(status=403, text="Invalid directory")
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404, text="Directory not found")
|
||||||
|
|
||||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||||
results = glob.glob(os.path.join(
|
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||||
glob.escape(path), '**/*'), recursive=recurse)
|
|
||||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
# Use different patterns based on whether we're recursing or not
|
||||||
|
if recurse:
|
||||||
|
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||||
|
else:
|
||||||
|
pattern = os.path.join(glob.escape(path), '*')
|
||||||
|
|
||||||
|
results = glob.glob(pattern, recursive=recurse)
|
||||||
|
|
||||||
|
if full_info:
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||||
|
'size': os.path.getsize(x),
|
||||||
|
'modified': os.path.getmtime(x)
|
||||||
|
} for x in results if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
results = [
|
||||||
|
os.path.relpath(x, path).replace(os.sep, '/')
|
||||||
|
for x in results
|
||||||
|
if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
|
||||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||||
if split_path:
|
if split_path and not full_info:
|
||||||
results = [[x] + x.split(os.sep) for x in results]
|
results = [[x] + x.split('/') for x in results]
|
||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
@@ -138,14 +188,14 @@ class UserManager():
|
|||||||
file = request.match_info.get(param, None)
|
file = request.match_info.get(param, None)
|
||||||
if not file:
|
if not file:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, file)
|
path = self.get_request_user_filepath(request, file)
|
||||||
if not path:
|
if not path:
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
|
|
||||||
if check_exists and not os.path.exists(path):
|
if check_exists and not os.path.exists(path):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
@routes.get("/userdata/{file}")
|
@routes.get("/userdata/{file}")
|
||||||
@@ -153,7 +203,7 @@ class UserManager():
|
|||||||
path = get_user_data_path(request, check_exists=True)
|
path = get_user_data_path(request, check_exists=True)
|
||||||
if not isinstance(path, str):
|
if not isinstance(path, str):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
return web.FileResponse(path)
|
return web.FileResponse(path)
|
||||||
|
|
||||||
@routes.post("/userdata/{file}")
|
@routes.post("/userdata/{file}")
|
||||||
@@ -161,7 +211,7 @@ class UserManager():
|
|||||||
path = get_user_data_path(request)
|
path = get_user_data_path(request)
|
||||||
if not isinstance(path, str):
|
if not isinstance(path, str):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
overwrite = request.query["overwrite"] != "false"
|
overwrite = request.query["overwrite"] != "false"
|
||||||
if not overwrite and os.path.exists(path):
|
if not overwrite and os.path.exists(path):
|
||||||
return web.Response(status=409)
|
return web.Response(status=409)
|
||||||
@@ -170,7 +220,7 @@ class UserManager():
|
|||||||
|
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
f.write(body)
|
f.write(body)
|
||||||
|
|
||||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||||
return web.json_response(resp)
|
return web.json_response(resp)
|
||||||
|
|
||||||
@@ -181,7 +231,7 @@ class UserManager():
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
return web.Response(status=204)
|
return web.Response(status=204)
|
||||||
|
|
||||||
@routes.post("/userdata/{file}/move/{dest}")
|
@routes.post("/userdata/{file}/move/{dest}")
|
||||||
@@ -189,17 +239,17 @@ class UserManager():
|
|||||||
source = get_user_data_path(request, check_exists=True)
|
source = get_user_data_path(request, check_exists=True)
|
||||||
if not isinstance(source, str):
|
if not isinstance(source, str):
|
||||||
return source
|
return source
|
||||||
|
|
||||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||||
if not isinstance(source, str):
|
if not isinstance(source, str):
|
||||||
return dest
|
return dest
|
||||||
|
|
||||||
overwrite = request.query["overwrite"] != "false"
|
overwrite = request.query["overwrite"] != "false"
|
||||||
if not overwrite and os.path.exists(dest):
|
if not overwrite and os.path.exists(dest):
|
||||||
return web.Response(status=409)
|
return web.Response(status=409)
|
||||||
|
|
||||||
print(f"moving '{source}' -> '{dest}'")
|
print(f"moving '{source}' -> '{dest}'")
|
||||||
shutil.move(source, dest)
|
shutil.move(source, dest)
|
||||||
|
|
||||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||||
return web.json_response(resp)
|
return web.json_response(resp)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
|
control_latent_channels = None,
|
||||||
dtype = None,
|
dtype = None,
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
for _ in range(len(self.joint_blocks)):
|
for _ in range(len(self.joint_blocks)):
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
None,
|
None,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
self.in_channels,
|
control_latent_channels,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
strict_img_size=False,
|
strict_img_size=False,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||||
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
|||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||||
|
|
||||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
# The default built-in provider hosted under web/
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
@@ -169,6 +171,8 @@ parser.add_argument(
|
|||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if k not in u:
|
if k not in u:
|
||||||
t = sd.pop(k)
|
sd.pop(k)
|
||||||
del t
|
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
|||||||
@@ -79,13 +79,21 @@ class ControlBase:
|
|||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
|
self.concat_mask = False
|
||||||
|
self.extra_concat_orig = []
|
||||||
|
self.extra_concat = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
|
if vae is None:
|
||||||
|
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
self.extra_concat_orig = extra_concat.copy()
|
||||||
|
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||||
|
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
@@ -100,9 +108,9 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
self.cond_hint = None
|
||||||
self.cond_hint = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
@@ -123,6 +131,8 @@ class ControlBase:
|
|||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
c.extra_conds = self.extra_conds.copy()
|
c.extra_conds = self.extra_conds.copy()
|
||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
|
c.concat_mask = self.concat_mask
|
||||||
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@@ -175,7 +185,7 @@ class ControlBase:
|
|||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@@ -189,6 +199,7 @@ class ControlNet(ControlBase):
|
|||||||
self.latent_format = latent_format
|
self.latent_format = latent_format
|
||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
|
self.concat_mask = concat_mask
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@@ -213,6 +224,9 @@ class ControlNet(ControlBase):
|
|||||||
compression_ratio = self.compression_ratio
|
compression_ratio = self.compression_ratio
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
compression_ratio *= self.vae.downscale_ratio
|
compression_ratio *= self.vae.downscale_ratio
|
||||||
|
else:
|
||||||
|
if self.latent_format is not None:
|
||||||
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
@@ -220,6 +234,13 @@ class ControlNet(ControlBase):
|
|||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
|
if len(self.extra_concat_orig) > 0:
|
||||||
|
to_concat = []
|
||||||
|
for c in self.extra_concat_orig:
|
||||||
|
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||||
|
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||||
|
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||||
|
|
||||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
@@ -319,7 +340,7 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
|
||||||
ControlBase.__init__(self, device)
|
ControlBase.__init__(self, device)
|
||||||
self.control_weights = control_weights
|
self.control_weights = control_weights
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
@@ -376,19 +397,25 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
def controlnet_config(sd):
|
def controlnet_config(sd, model_options={}):
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
|
||||||
operations = comfy.ops.manual_cast
|
operations = model_options.get("custom_operations", None)
|
||||||
else:
|
if operations is None:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
@@ -403,24 +430,29 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
return control_model
|
return control_model
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def load_controlnet_mmdit(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
concat_mask = False
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||||
|
if control_latent_channels == 17: #inpaint controlnet
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data):
|
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||||
@@ -430,17 +462,17 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_instantx(sd):
|
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
@@ -449,21 +481,30 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
if union_cnet in new_sd:
|
if union_cnet in new_sd:
|
||||||
num_union_modes = new_sd[union_cnet].shape[0]
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.Flux()
|
latent_format = comfy.latent_formats.Flux()
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def convert_mistoline(sd):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
supported_inference_dtypes = None
|
supported_inference_dtypes = None
|
||||||
@@ -518,13 +559,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@@ -536,25 +579,36 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
@@ -590,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@@ -653,7 +714,7 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
|
||||||
|
|
||||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
mantissa_scaled = torch.where(
|
mantissa_scaled = torch.where(
|
||||||
@@ -41,9 +40,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|||||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
)
|
)
|
||||||
del abs_x
|
|
||||||
|
|
||||||
return sign.to(dtype=dtype)
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -57,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
return value.to(dtype=dtype)
|
||||||
|
|||||||
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
|||||||
return append_zero(sigmas)
|
return append_zero(sigmas)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||||
|
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||||
|
epsilon = 1e-5 # avoid log(0)
|
||||||
|
x = torch.linspace(0, 1, n, device=device)
|
||||||
|
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||||
|
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||||
|
sigmas = clamp(torch.exp(lmb))
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def to_d(x, sigma, denoised):
|
def to_d(x, sigma, denoised):
|
||||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||||
@@ -1101,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
|||||||
if sigmas[i + 1] > 0:
|
if sigmas[i + 1] > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigma_down == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = denoised + d * sigma_down
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||||
|
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_next - t
|
||||||
|
s = t + r * h
|
||||||
|
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||||
|
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
"""DPM-Solver++(2M)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
old_uncond_denoised = None
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||||
|
h = t_next - t
|
||||||
|
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||||
|
else:
|
||||||
|
h_last = t - t_fn(sigmas[i - 1])
|
||||||
|
r = h_last / h
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||||
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
|
old_uncond_denoised = uncond_denoised
|
||||||
|
return x
|
||||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
|||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
latent_channels = 4
|
latent_channels = 4
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
|
latent_rgb_factors_bias = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.3920, 0.4054, 0.4549],
|
[ 0.3651, 0.4232, 0.4341],
|
||||||
[-0.2634, -0.0196, 0.0653],
|
[-0.2533, -0.0042, 0.1068],
|
||||||
[ 0.0568, 0.1687, -0.0755],
|
[ 0.1076, 0.1111, -0.0362],
|
||||||
[-0.3112, -0.2359, -0.2076]
|
[-0.3165, -0.2492, -0.2188]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||||
|
|
||||||
self.taesd_decoder_name = "taesdxl_decoder"
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
|
||||||
class SDXL_Playground_2_5(LatentFormat):
|
class SDXL_Playground_2_5(LatentFormat):
|
||||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
|||||||
self.scale_factor = 1.5305
|
self.scale_factor = 1.5305
|
||||||
self.shift_factor = 0.0609
|
self.shift_factor = 0.0609
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
[-0.0645, 0.0177, 0.1052],
|
[-0.0922, -0.0175, 0.0749],
|
||||||
[ 0.0028, 0.0312, 0.0650],
|
[ 0.0311, 0.0633, 0.0954],
|
||||||
[ 0.1848, 0.0762, 0.0360],
|
[ 0.1994, 0.0927, 0.0458],
|
||||||
[ 0.0944, 0.0360, 0.0889],
|
[ 0.0856, 0.0339, 0.0902],
|
||||||
[ 0.0897, 0.0506, -0.0364],
|
[ 0.0587, 0.0272, -0.0496],
|
||||||
[-0.0020, 0.1203, 0.0284],
|
[-0.0006, 0.1104, 0.0309],
|
||||||
[ 0.0855, 0.0118, 0.0283],
|
[ 0.0978, 0.0306, 0.0427],
|
||||||
[-0.0539, 0.0658, 0.1047],
|
[-0.0042, 0.1038, 0.1358],
|
||||||
[-0.0057, 0.0116, 0.0700],
|
[-0.0194, 0.0020, 0.0669],
|
||||||
[-0.0412, 0.0281, -0.0039],
|
[-0.0488, 0.0130, -0.0268],
|
||||||
[ 0.1106, 0.1171, 0.1220],
|
[ 0.0922, 0.0988, 0.0951],
|
||||||
[-0.0248, 0.0682, -0.0481],
|
[-0.0278, 0.0524, -0.0542],
|
||||||
[ 0.0815, 0.0846, 0.1207],
|
[ 0.0332, 0.0456, 0.0895],
|
||||||
[-0.0120, -0.0055, -0.0867],
|
[-0.0069, -0.0030, -0.0810],
|
||||||
[-0.0749, -0.0634, -0.0456],
|
[-0.0596, -0.0465, -0.0293],
|
||||||
[-0.1418, -0.1457, -0.1259]
|
[-0.1448, -0.1463, -0.1189]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||||
self.taesd_decoder_name = "taesd3_decoder"
|
self.taesd_decoder_name = "taesd3_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
@@ -146,23 +150,24 @@ class Flux(SD3):
|
|||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
self.latent_rgb_factors =[
|
self.latent_rgb_factors =[
|
||||||
[-0.0404, 0.0159, 0.0609],
|
[-0.0346, 0.0244, 0.0681],
|
||||||
[ 0.0043, 0.0298, 0.0850],
|
[ 0.0034, 0.0210, 0.0687],
|
||||||
[ 0.0328, -0.0749, -0.0503],
|
[ 0.0275, -0.0668, -0.0433],
|
||||||
[-0.0245, 0.0085, 0.0549],
|
[-0.0174, 0.0160, 0.0617],
|
||||||
[ 0.0966, 0.0894, 0.0530],
|
[ 0.0859, 0.0721, 0.0329],
|
||||||
[ 0.0035, 0.0399, 0.0123],
|
[ 0.0004, 0.0383, 0.0115],
|
||||||
[ 0.0583, 0.1184, 0.1262],
|
[ 0.0405, 0.0861, 0.0915],
|
||||||
[-0.0191, -0.0206, -0.0306],
|
[-0.0236, -0.0185, -0.0259],
|
||||||
[-0.0324, 0.0055, 0.1001],
|
[-0.0245, 0.0250, 0.1180],
|
||||||
[ 0.0955, 0.0659, -0.0545],
|
[ 0.1008, 0.0755, -0.0421],
|
||||||
[-0.0504, 0.0231, -0.0013],
|
[-0.0515, 0.0201, 0.0011],
|
||||||
[ 0.0500, -0.0008, -0.0088],
|
[ 0.0428, -0.0012, -0.0036],
|
||||||
[ 0.0982, 0.0941, 0.0976],
|
[ 0.0817, 0.0765, 0.0749],
|
||||||
[-0.1233, -0.0280, -0.0897],
|
[-0.1264, -0.0522, -0.1103],
|
||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0280, -0.0881, -0.0499],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1262, -0.0982, -0.0778]
|
||||||
]
|
]
|
||||||
|
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||||
self.taesd_decoder_name = "taef1_decoder"
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ except:
|
|||||||
rms_norm_torch = None
|
rms_norm_torch = None
|
||||||
|
|
||||||
def rms_norm(x, weight, eps=1e-6):
|
def rms_norm(x, weight, eps=1e-6):
|
||||||
if rms_norm_torch is not None:
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
else:
|
else:
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
#modified to support different types of flux controlnets
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
@@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|||||||
from .model import Flux
|
from .model import Flux
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
self.main_model_single = 38
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
# add ControlNet blocks
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth):
|
for _ in range(self.params.depth):
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
self.controlnet_blocks.append(control_block())
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth_single_blocks):
|
for _ in range(self.params.depth_single_blocks):
|
||||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
self.num_union_modes = num_union_modes
|
self.num_union_modes = num_union_modes
|
||||||
self.controlnet_mode_embedder = None
|
self.controlnet_mode_embedder = None
|
||||||
@@ -36,25 +80,33 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
else:
|
||||||
|
control_latent_channels *= 2 * 2 #patch size
|
||||||
|
|
||||||
|
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
if not self.latent_input:
|
if not self.latent_input:
|
||||||
self.input_hint_block = nn.Sequential(
|
if self.mistoline:
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
nn.SiLU(),
|
else:
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_hint_block = nn.Sequential(
|
||||||
nn.SiLU(),
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
nn.SiLU(),
|
||||||
)
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
@@ -73,9 +125,6 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
if not self.latent_input:
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
img = img + controlnet_cond
|
img = img + controlnet_cond
|
||||||
@@ -131,9 +180,14 @@ class ControlNetFlux(Flux):
|
|||||||
patch_size = 2
|
patch_size = 2
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
else:
|
else:
|
||||||
hint = hint * 2.0 - 1.0
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Flux(nn.Module):
|
|||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y)
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
@@ -151,8 +151,8 @@ class Flux(nn.Module):
|
|||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
|||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if "emb_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["emb_patch"]
|
||||||
|
for p in patch:
|
||||||
|
emb = p(emb, self.model_channels, transformer_options)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
@@ -201,9 +201,13 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sdk = model.state_dict().keys()
|
||||||
|
for k in sdk:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
|
clip_g_present = False
|
||||||
for b in range(32): #TODO: clean up
|
for b in range(32): #TODO: clean up
|
||||||
for c in LORA_CLIP_MAP:
|
for c in LORA_CLIP_MAP:
|
||||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
@@ -227,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
clip_g_present = True
|
||||||
if clip_l_present:
|
if clip_l_present:
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
@@ -242,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
t5_index = 1
|
||||||
key_map[lora_key] = k
|
if clip_g_present:
|
||||||
|
t5_index += 1
|
||||||
|
if clip_l_present:
|
||||||
|
t5_index += 1
|
||||||
|
if t5_index == 2:
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||||
|
t5_index += 1
|
||||||
|
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||||
@@ -281,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||||
|
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||||
|
|
||||||
diffusers_lora_prefix = ["", "unet."]
|
diffusers_lora_prefix = ["", "unet."]
|
||||||
for p in diffusers_lora_prefix:
|
for p in diffusers_lora_prefix:
|
||||||
@@ -324,14 +338,15 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
lora_diff *= alpha
|
lora_diff *= alpha
|
||||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight_calc.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight_calc.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
@@ -400,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
|
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
@@ -438,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -484,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -521,28 +536,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
try:
|
try:
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
logging.info("pytorch version: {}".format(torch_version))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -326,7 +326,7 @@ class LoadedModel:
|
|||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
|||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
|
if model_params < 0:
|
||||||
|
model_params = 1000000000000000000000
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
@@ -897,7 +899,7 @@ def force_upcast_attention_dtype():
|
|||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
try:
|
||||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -1063,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_fp8_compute(device=None):
|
def supports_fp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties(device)
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 9:
|
if props.major >= 9:
|
||||||
return True
|
return True
|
||||||
@@ -1070,6 +1075,14 @@ def supports_fp8_compute(device=None):
|
|||||||
return False
|
return False
|
||||||
if props.minor < 9:
|
if props.minor < 9:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if WINDOWS:
|
||||||
|
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import comfy.utils
|
|||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
from comfy.types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
@@ -88,8 +88,12 @@ class LowVramPatch:
|
|||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
intermediate_dtype = weight.dtype
|
||||||
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
|
intermediate_dtype = torch.float32
|
||||||
|
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
|
||||||
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@@ -283,17 +287,21 @@ class ModelPatcher:
|
|||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
def get_key_patches(self, filter_prefix=None):
|
def get_key_patches(self, filter_prefix=None):
|
||||||
comfy.model_management.unload_model_clones(self)
|
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
p = {}
|
p = {}
|
||||||
for k in model_sd:
|
for k in model_sd:
|
||||||
if filter_prefix is not None:
|
if filter_prefix is not None:
|
||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
if k in self.patches:
|
bk = self.backup.get(k, None)
|
||||||
p[k] = [model_sd[k]] + self.patches[k]
|
if bk is not None:
|
||||||
|
weight = bk.weight
|
||||||
else:
|
else:
|
||||||
p[k] = (model_sd[k],)
|
weight = model_sd[k]
|
||||||
|
if k in self.patches:
|
||||||
|
p[k] = [weight] + self.patches[k]
|
||||||
|
else:
|
||||||
|
p[k] = (weight,)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
|
|||||||
@@ -260,7 +260,6 @@ def fp8_linear(self, input):
|
|||||||
|
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
|
|
||||||
@@ -300,10 +299,14 @@ class fp8_ops(manual_cast):
|
|||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
|
||||||
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
|
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||||
|
return fp8_ops
|
||||||
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
if args.fast:
|
if args.fast and not disable_fast_fp8:
|
||||||
if comfy.model_management.supports_fp8_compute(load_device):
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
return manual_cast
|
return manual_cast
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from comfy import model_management
|
|||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import scipy
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
@@ -570,8 +570,8 @@ class Sampler:
|
|||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
|
|||||||
79
comfy/sd.py
79
comfy/sd.py
@@ -29,7 +29,6 @@ import comfy.text_encoders.long_clipl
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
import comfy.supported_models_base
|
|
||||||
import comfy.taesd.taesd
|
import comfy.taesd.taesd
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
@@ -70,14 +69,14 @@ class CLIP:
|
|||||||
clip = target.clip
|
clip = target.clip
|
||||||
tokenizer = target.tokenizer
|
tokenizer = target.tokenizer
|
||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||||
params['model_options'] = model_options
|
params['model_options'] = model_options
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
@@ -348,7 +347,7 @@ class VAE:
|
|||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / max(1, memory_used))
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
@@ -406,8 +405,35 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class TEModel(Enum):
|
||||||
|
CLIP_L = 1
|
||||||
|
CLIP_H = 2
|
||||||
|
CLIP_G = 3
|
||||||
|
T5_XXL = 4
|
||||||
|
T5_XL = 5
|
||||||
|
T5_BASE = 6
|
||||||
|
|
||||||
|
def detect_te_model(sd):
|
||||||
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_G
|
||||||
|
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_H
|
||||||
|
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||||
|
return TEModel.CLIP_L
|
||||||
|
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||||
|
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
|
if weight.shape[-1] == 4096:
|
||||||
|
return TEModel.T5_XXL
|
||||||
|
elif weight.shape[-1] == 2048:
|
||||||
|
return TEModel.T5_XL
|
||||||
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
|
return TEModel.T5_BASE
|
||||||
|
return None
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -421,39 +447,42 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target = EmptyClass()
|
clip_target = EmptyClass()
|
||||||
clip_target.params = {}
|
clip_target.params = {}
|
||||||
if len(clip_data) == 1:
|
if len(clip_data) == 1:
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
te_model = detect_te_model(clip_data[0])
|
||||||
|
if te_model == TEModel.CLIP_G:
|
||||||
if clip_type == CLIPType.STABLE_CASCADE:
|
if clip_type == CLIPType.STABLE_CASCADE:
|
||||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||||
|
elif clip_type == CLIPType.SD3:
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
elif te_model == TEModel.CLIP_H:
|
||||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
elif te_model == TEModel.T5_XXL:
|
||||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
dtype_t5 = weight.dtype
|
dtype_t5 = weight.dtype
|
||||||
if weight.shape[-1] == 4096:
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
elif te_model == TEModel.T5_XL:
|
||||||
elif weight.shape[-1] == 2048:
|
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
elif te_model == TEModel.T5_BASE:
|
||||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
|
||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
if clip_type == CLIPType.SD3:
|
||||||
if w is not None and w.shape[0] == 248:
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
elif len(clip_data) == 2:
|
elif len(clip_data) == 2:
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||||
|
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||||
@@ -475,10 +504,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
parameters = 0
|
parameters = 0
|
||||||
|
tokenizer_data = {}
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@@ -562,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix)
|
model.load_model_weights(sd, diffusion_model_prefix)
|
||||||
|
|
||||||
@@ -647,7 +677,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
|
if model_options.get("fp8_optimizations", False):
|
||||||
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
|||||||
@@ -542,6 +542,7 @@ class SD1Tokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -570,6 +571,7 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
|
||||||
|
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SDXLTokenizer:
|
class SDXLTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -40,7 +41,8 @@ class SDXLTokenizer:
|
|||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = set([dtype])
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_l = token_weight_pairs["l"]
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||||
|
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class BASE:
|
|||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def matches(s, unet_config, state_dict=None):
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ class T5XXLModel(sd1_clip.SDClipModel):
|
|||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||||
|
|
||||||
|
|
||||||
class FluxTokenizer:
|
class FluxTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes = set([dtype, dtype_t5])
|
self.dtypes = set([dtype, dtype_t5])
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
|||||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
class LongClipModel_(sd1_clip.SDClipModel):
|
class LongClipModel_(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, *args, **kwargs):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||||
|
|
||||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||||
|
|
||||||
|
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||||
|
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is None:
|
||||||
|
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is not None and w.shape[0] == 248:
|
||||||
|
tokenizer_data = tokenizer_data.copy()
|
||||||
|
model_options = model_options.copy()
|
||||||
|
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||||
|
model_options["clip_l_class"] = LongClipModel_
|
||||||
|
return tokenizer_data, model_options
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
@@ -42,7 +43,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
@@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if self.clip_g is not None:
|
if self.clip_g is not None:
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||||
|
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||||
else:
|
else:
|
||||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||||
|
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||||
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
@@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
|
|
||||||
|
# handle entire input fitting in a single tile
|
||||||
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||||
|
output[b:b+1] = function(s).to(output_device)
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
|
|
||||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
|
for it in itertools.product(*positions):
|
||||||
s_in = s
|
s_in = s
|
||||||
upscaled = []
|
upscaled = []
|
||||||
|
|
||||||
@@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(pos * upscale_amount))
|
upscaled.append(round(pos * upscale_amount))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
|
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
m = mask.narrow(d, t, 1)
|
a = (t + 1) / feather
|
||||||
m *= ((1.0/feather) * (t + 1))
|
mask.narrow(d, t, 1).mul_(a)
|
||||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||||
m *= ((1.0/feather) * (t + 1))
|
|
||||||
|
|
||||||
o = out
|
o = out
|
||||||
o_d = out_div
|
o_d = out_div
|
||||||
@@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
o += ps * mask
|
o.add_(ps * mask)
|
||||||
o_d += mask
|
o_d.add_(mask)
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Mapping
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet:
|
class CacheKeySet:
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
signature.append(node_id)
|
signature.append(node_id)
|
||||||
inputs = node["inputs"]
|
inputs = node["inputs"]
|
||||||
for key in sorted(inputs.keys()):
|
for key in sorted(inputs.keys()):
|
||||||
|
|||||||
@@ -99,30 +99,44 @@ class TopologicalSort:
|
|||||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
self.add_node(from_node_id)
|
if not self.is_cached(from_node_id):
|
||||||
if to_node_id not in self.blocking[from_node_id]:
|
self.add_node(from_node_id)
|
||||||
self.blocking[from_node_id][to_node_id] = {}
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
self.blockCount[to_node_id] += 1
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
if unique_id in self.pendingNodes:
|
node_ids = [node_unique_id]
|
||||||
return
|
links = []
|
||||||
self.pendingNodes[unique_id] = True
|
|
||||||
self.blockCount[unique_id] = 0
|
|
||||||
self.blocking[unique_id] = {}
|
|
||||||
|
|
||||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
while len(node_ids) > 0:
|
||||||
for input_name in inputs:
|
unique_id = node_ids.pop()
|
||||||
value = inputs[input_name]
|
if unique_id in self.pendingNodes:
|
||||||
if is_link(value):
|
continue
|
||||||
from_node_id, from_socket = value
|
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
self.pendingNodes[unique_id] = True
|
||||||
continue
|
self.blockCount[unique_id] = 0
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
self.blocking[unique_id] = {}
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
|
||||||
if include_lazy or not is_lazy:
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return False
|
||||||
|
|
||||||
def get_ready_nodes(self):
|
def get_ready_nodes(self):
|
||||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.output_cache = output_cache
|
self.output_cache = output_cache
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
def is_cached(self, node_id):
|
||||||
if self.output_cache.get(from_node_id) is not None:
|
return self.output_cache.get(node_id) is not None
|
||||||
# Nothing to do
|
|
||||||
return
|
|
||||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def stage_node_execution(self):
|
def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
|
|||||||
@@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
CATEGORY = "latent/audio"
|
||||||
|
|
||||||
def generate(self, seconds):
|
def generate(self, seconds, batch_size):
|
||||||
batch_size = 1
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return ({"samples":latent, "type": "audio"}, )
|
||||||
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
|||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
|
std[std < 1.0] = 1.0
|
||||||
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||||
|
|
||||||
|
|
||||||
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio:
|
||||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [
|
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||||
f for f in os.listdir(input_dir)
|
|
||||||
if (os.path.isfile(os.path.join(input_dir, f))
|
|
||||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||||
|
|
||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType:
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
|||||||
|
|
||||||
return (control_net,)
|
return (control_net,)
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
FUNCTION = "apply_inpaint_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
|
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||||
|
extra_concat = []
|
||||||
|
if control_net.concat_mask:
|
||||||
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
|
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||||
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
|
extra_concat = [mask]
|
||||||
|
|
||||||
|
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
|||||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class LaplaceScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
||||||
|
return (sigmas, )
|
||||||
|
|
||||||
|
|
||||||
class SDTurboScheduler:
|
class SDTurboScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"KarrasScheduler": KarrasScheduler,
|
"KarrasScheduler": KarrasScheduler,
|
||||||
"ExponentialScheduler": ExponentialScheduler,
|
"ExponentialScheduler": ExponentialScheduler,
|
||||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||||
|
"LaplaceScheduler": LaplaceScheduler,
|
||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class HypernetworkLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
|
|||||||
115
comfy_extras/nodes_lora_extract.py
Normal file
115
comfy_extras/nodes_lora_extract.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
def extract_lora(diff, rank):
|
||||||
|
conv2d = (len(diff.shape) == 4)
|
||||||
|
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
out_dim, in_dim = diff.size()[0:2]
|
||||||
|
rank = min(rank, in_dim, out_dim)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
if conv2d_3x3:
|
||||||
|
diff = diff.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
diff = diff.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(diff.float())
|
||||||
|
U = U[:, :rank]
|
||||||
|
S = S[:rank]
|
||||||
|
U = U @ torch.diag(S)
|
||||||
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
|
low_val = -hi_val
|
||||||
|
|
||||||
|
U = U.clamp(low_val, hi_val)
|
||||||
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
return (U, Vh)
|
||||||
|
|
||||||
|
class LORAType(Enum):
|
||||||
|
STANDARD = 0
|
||||||
|
FULL_DIFF = 1
|
||||||
|
|
||||||
|
LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||||
|
"full_diff": LORAType.FULL_DIFF}
|
||||||
|
|
||||||
|
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||||
|
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||||
|
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
weight_diff = sd[k]
|
||||||
|
if lora_type == LORAType.STANDARD:
|
||||||
|
if weight_diff.ndim < 2:
|
||||||
|
if bias_diff:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out = extract_lora(weight_diff, rank)
|
||||||
|
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
||||||
|
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
||||||
|
except:
|
||||||
|
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
||||||
|
elif lora_type == LORAType.FULL_DIFF:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
|
||||||
|
elif bias_diff and k.endswith(".bias"):
|
||||||
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
|
return output_sd
|
||||||
|
|
||||||
|
class LoraSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||||
|
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||||
|
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {"model_diff": ("MODEL",),
|
||||||
|
"text_encoder_diff": ("CLIP",)},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||||
|
if model_diff is None and text_encoder_diff is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
output_sd = {}
|
||||||
|
if model_diff is not None:
|
||||||
|
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
if text_encoder_diff is not None:
|
||||||
|
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
|
||||||
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoraSave": LoraSave
|
||||||
|
}
|
||||||
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class PerpNeg:
|
|||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def patch(self, model, empty_conditioning, neg_scale):
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class PhotoMakerLoader:
|
|||||||
CATEGORY = "_for_testing/photomaker"
|
CATEGORY = "_for_testing/photomaker"
|
||||||
|
|
||||||
def load_photomaker_model(self, photomaker_model_name):
|
def load_photomaker_model(self, photomaker_model_name):
|
||||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||||
photomaker_model = PhotoMakerIDEncoder()
|
photomaker_model = PhotoMakerIDEncoder()
|
||||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||||
if "id_encoder" in data:
|
if "id_encoder" in data:
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
|
|||||||
CATEGORY = "latent/sd3"
|
CATEGORY = "latent/sd3"
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3:
|
||||||
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
|
|||||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
|
EXPERIMENTAL = True
|
||||||
CATEGORY = "_for_testing/stable_cascade"
|
CATEGORY = "_for_testing/stable_cascade"
|
||||||
|
|
||||||
def generate(self, image, vae):
|
def generate(self, image, vae):
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class TomePatchModel:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "model_patches/unet"
|
||||||
|
|
||||||
def patch(self, model, ratio):
|
def patch(self, model, ratio):
|
||||||
self.u = None
|
self.u = None
|
||||||
|
|||||||
22
comfy_extras/nodes_torch_compile.py
Normal file
22
comfy_extras/nodes_torch_compile.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class TorchCompileModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"backend": (["inductor", "cudagraphs"],),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model, backend):
|
||||||
|
m = model.clone()
|
||||||
|
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TorchCompileModel": TorchCompileModel,
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ class UpscaleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
|
|||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class SaveImageWebsocket:
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def IS_CHANGED(s, images):
|
def IS_CHANGED(s, images):
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
|
|||||||
# merge node execution results
|
# merge node execution results
|
||||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
if is_list:
|
if is_list:
|
||||||
output.append([x for o in results for x in o[i]])
|
value = []
|
||||||
|
for o in results:
|
||||||
|
if isinstance(o[i], ExecutionBlocker):
|
||||||
|
value.append(o[i])
|
||||||
|
else:
|
||||||
|
value.extend(o[i])
|
||||||
|
output.append(value)
|
||||||
else:
|
else:
|
||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -25,11 +25,16 @@ a111:
|
|||||||
|
|
||||||
#comfyui:
|
#comfyui:
|
||||||
# base_path: path/to/comfyui/
|
# base_path: path/to/comfyui/
|
||||||
|
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
|
||||||
|
# #is_default: true
|
||||||
# checkpoints: models/checkpoints/
|
# checkpoints: models/checkpoints/
|
||||||
# clip: models/clip/
|
# clip: models/clip/
|
||||||
# clip_vision: models/clip_vision/
|
# clip_vision: models/clip_vision/
|
||||||
# configs: models/configs/
|
# configs: models/configs/
|
||||||
# controlnet: models/controlnet/
|
# controlnet: models/controlnet/
|
||||||
|
# diffusion_models: |
|
||||||
|
# models/diffusion_models
|
||||||
|
# models/unet
|
||||||
# embeddings: models/embeddings/
|
# embeddings: models/embeddings/
|
||||||
# loras: models/loras/
|
# loras: models/loras/
|
||||||
# upscale_models: models/upscale_models/
|
# upscale_models: models/upscale_models/
|
||||||
|
|||||||
103
folder_paths.py
103
folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mimetypes
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set, List, Dict, Tuple, Literal
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
class CacheHelper:
|
||||||
|
"""
|
||||||
|
Helper class for managing file list cache data.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
self.active = False
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not self.active:
|
||||||
|
return default
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
|
if self.active:
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.active = False
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
cache_helper = CacheHelper()
|
||||||
|
|
||||||
|
extension_mimetypes_cache = {
|
||||||
|
"webp" : "image",
|
||||||
|
}
|
||||||
|
|
||||||
def map_legacy(folder_name: str) -> str:
|
def map_legacy(folder_name: str) -> str:
|
||||||
legacy = {"unet": "diffusion_models"}
|
legacy = {"unet": "diffusion_models"}
|
||||||
return legacy.get(folder_name, folder_name)
|
return legacy.get(folder_name, folder_name)
|
||||||
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
|
|||||||
global input_directory
|
global input_directory
|
||||||
return input_directory
|
return input_directory
|
||||||
|
|
||||||
|
def get_user_directory() -> str:
|
||||||
|
return user_directory
|
||||||
|
|
||||||
|
def set_user_directory(user_dir: str) -> None:
|
||||||
|
global user_directory
|
||||||
|
user_directory = user_dir
|
||||||
|
|
||||||
|
|
||||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||||
def get_directory_by_type(type_name: str) -> str | None:
|
def get_directory_by_type(type_name: str) -> str | None:
|
||||||
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
|
"""
|
||||||
|
global extension_mimetypes_cache
|
||||||
|
result = []
|
||||||
|
for file in files:
|
||||||
|
extension = file.split('.')[-1]
|
||||||
|
if extension not in extension_mimetypes_cache:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||||
|
if not mime_type:
|
||||||
|
continue
|
||||||
|
content_type = mime_type.split('/')[0]
|
||||||
|
extension_mimetypes_cache[extension] = content_type
|
||||||
|
else:
|
||||||
|
content_type = extension_mimetypes_cache[extension]
|
||||||
|
|
||||||
|
if content_type in content_types:
|
||||||
|
result.append(file)
|
||||||
|
return result
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
if is_default:
|
||||||
|
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||||
|
else:
|
||||||
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
|
|||||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||||
for file_name in filenames:
|
for file_name in filenames:
|
||||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
try:
|
||||||
result.append(relative_path)
|
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||||
|
result.append(relative_path)
|
||||||
|
except:
|
||||||
|
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||||
|
continue
|
||||||
|
|
||||||
for d in subdirs:
|
for d in subdirs:
|
||||||
path: str = os.path.join(dirpath, d)
|
path: str = os.path.join(dirpath, d)
|
||||||
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
full_path = get_full_path(folder_name, filename)
|
||||||
|
if full_path is None:
|
||||||
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||||
|
|
||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
|
strong_cache = cache_helper.get(folder_name)
|
||||||
|
if strong_cache is not None:
|
||||||
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
|
|||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
filename_list_cache[folder_name] = out
|
filename_list_cache[folder_name] = out
|
||||||
|
cache_helper.set(folder_name, out)
|
||||||
return list(out[0])
|
return list(out[0])
|
||||||
|
|
||||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||||
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||||
input = input.replace("%width%", str(image_width))
|
input = input.replace("%width%", str(image_width))
|
||||||
input = input.replace("%height%", str(image_height))
|
input = input.replace("%height%", str(image_height))
|
||||||
|
now = time.localtime()
|
||||||
|
input = input.replace("%year%", str(now.tm_year))
|
||||||
|
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||||
|
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||||
|
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||||
|
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||||
|
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||||
return input
|
return input
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
if "%" in filename_prefix:
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self, latent_rgb_factors):
|
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
if latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
if self.latent_rgb_factors_bias is not None:
|
||||||
|
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||||
|
|
||||||
|
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
|||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
if latent_format.latent_rgb_factors is not None:
|
if latent_format.latent_rgb_factors is not None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
def prepare_callback(model, steps, x0_output_dict=None):
|
def prepare_callback(model, steps, x0_output_dict=None):
|
||||||
|
|||||||
43
main.py
43
main.py
@@ -9,7 +9,7 @@ from comfy.cli_args import args
|
|||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
setup_logger(verbose=args.verbose)
|
setup_logger(log_level=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@@ -63,6 +63,7 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import utils.extra_config
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@@ -85,7 +86,6 @@ if args.windows_standalone_build:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import yaml
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@@ -160,7 +160,10 @@ def prompt_worker(q, server):
|
|||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
addresses = []
|
||||||
|
for addr in address.split(","):
|
||||||
|
addresses.append((addr, port))
|
||||||
|
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
@@ -180,27 +183,6 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def load_extra_path_config(yaml_path):
|
|
||||||
with open(yaml_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
for c in config:
|
|
||||||
conf = config[c]
|
|
||||||
if conf is None:
|
|
||||||
continue
|
|
||||||
base_path = None
|
|
||||||
if "base_path" in conf:
|
|
||||||
base_path = conf.pop("base_path")
|
|
||||||
for x in conf:
|
|
||||||
for y in conf[x].split("\n"):
|
|
||||||
if len(y) == 0:
|
|
||||||
continue
|
|
||||||
full_path = y
|
|
||||||
if base_path is not None:
|
|
||||||
full_path = os.path.join(base_path, full_path)
|
|
||||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
|
||||||
folder_paths.add_model_folder_path(x, full_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
@@ -222,11 +204,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||||
load_extra_path_config(config_path)
|
utils.extra_config.load_extra_path_config(config_path)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
@@ -247,21 +229,30 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
logging.info(f"Setting input directory to: {input_dir}")
|
logging.info(f"Setting input directory to: {input_dir}")
|
||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
|
if args.user_directory:
|
||||||
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
|
logging.info(f"Setting user directory to: {user_dir}")
|
||||||
|
folder_paths.set_user_directory(user_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
address = '127.0.0.1'
|
address = '127.0.0.1'
|
||||||
|
if ':' in address:
|
||||||
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
#NOTE: This was an experiment and WILL BE REMOVED
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@@ -29,7 +31,7 @@ class DownloadModelStatus():
|
|||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
@@ -38,102 +40,112 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
Download a model file from a given URL into the models directory.
|
Download a model file from a given URL into the models directory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||||
model_name (str):
|
model_name (str):
|
||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str);
|
||||||
|
Path to which model folder should be used as the root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
0,
|
0,
|
||||||
"Invalid model name",
|
"Invalid model name",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
file_path = os.path.join(folder_path, model_name)
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
# Ensure the resulting path is still within the base directory
|
||||||
abs_file_path = os.path.abspath(file_path)
|
abs_file_path = os.path.abspath(file_path)
|
||||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
abs_base_dir = os.path.abspath(folder_path)
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
async def check_file_exists(file_path: str,
|
||||||
return file_path, relative_path
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||||
async def check_file_exists(file_path: str,
|
) -> Optional[DownloadModelStatus]:
|
||||||
model_name: str,
|
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
temp_file_path = file_path + '.tmp'
|
||||||
|
with open(temp_file_path, 'wb') as f:
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
break
|
break
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
downloaded += len(chunk)
|
downloaded += len(chunk)
|
||||||
|
|
||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
|
os.rename(temp_file_path, file_path)
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
|
||||||
model_name: str,
|
async def handle_download_error(e: Exception,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
model_name: str,
|
||||||
relative_path: str) -> DownloadModelStatus:
|
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||||
|
) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the model subdirectory is safe to install into.
|
|
||||||
Must not contain relative paths, nested paths or special characters
|
|
||||||
other than underscores and hyphens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_subdirectory (str): The subdirectory for the specific model type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the subdirectory is safe, False otherwise.
|
|
||||||
"""
|
|
||||||
if len(model_subdirectory) > 50:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
def validate_filename(filename: str)-> bool:
|
||||||
"""
|
"""
|
||||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): The filename to validate
|
filename (str): The filename to validate
|
||||||
|
|
||||||
|
|||||||
54
nodes.py
54
nodes.py
@@ -511,10 +511,11 @@ class CheckpointLoader:
|
|||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name):
|
def load_checkpoint(self, config_name, ckpt_name):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
|
|||||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -624,7 +625,7 @@ class LoraLoader:
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (model, clip)
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
@@ -703,11 +704,11 @@ class VAELoader:
|
|||||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||||
|
|
||||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
|
||||||
for k in enc:
|
for k in enc:
|
||||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||||
|
|
||||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
|
||||||
for k in dec:
|
for k in dec:
|
||||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||||
|
|
||||||
@@ -738,7 +739,7 @@ class VAELoader:
|
|||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
@@ -754,7 +755,7 @@ class ControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name):
|
def load_controlnet(self, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -770,7 +771,7 @@ class DiffControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, model, control_net_name):
|
def load_controlnet(self, model, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@@ -786,6 +787,7 @@ class ControlNetApply:
|
|||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "apply_controlnet"
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
DEPRECATED = True
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||||
@@ -815,7 +817,10 @@ class ControlNetApplyAdvanced:
|
|||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
},
|
||||||
|
"optional": {"vae": ("VAE", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@@ -823,7 +828,7 @@ class ControlNetApplyAdvanced:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
@@ -840,7 +845,7 @@ class ControlNetApplyAdvanced:
|
|||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
@@ -856,7 +861,7 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
@@ -867,10 +872,13 @@ class UNETLoader:
|
|||||||
model_options = {}
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
model_options["fp8_optimizations"] = True
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
@@ -895,7 +903,7 @@ class CLIPLoader:
|
|||||||
else:
|
else:
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@@ -912,8 +920,8 @@ class DualCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type):
|
def load_clip(self, clip_name1, clip_name2, type):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
if type == "sdxl":
|
if type == "sdxl":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
elif type == "sd3":
|
elif type == "sd3":
|
||||||
@@ -935,7 +943,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
|
||||||
clip_vision = comfy.clip_vision.load(clip_path)
|
clip_vision = comfy.clip_vision.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
@@ -965,7 +973,7 @@ class StyleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_style_model(self, style_model_name):
|
def load_style_model(self, style_model_name):
|
||||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
|
||||||
style_model = comfy.sd.load_style_model(style_model_path)
|
style_model = comfy.sd.load_style_model(style_model_path)
|
||||||
return (style_model,)
|
return (style_model,)
|
||||||
|
|
||||||
@@ -1030,7 +1038,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
|
||||||
gligen = comfy.sd.load_gligen(gligen_path)
|
gligen = comfy.sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
@@ -1916,8 +1924,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||||
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||||
@@ -2101,6 +2109,8 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_controlnet.py",
|
"nodes_controlnet.py",
|
||||||
"nodes_hunyuan.py",
|
"nodes_hunyuan.py",
|
||||||
"nodes_flux.py",
|
"nodes_flux.py",
|
||||||
|
"nodes_lora_extract.py",
|
||||||
|
"nodes_torch_compile.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ def get_images(ws, prompt):
|
|||||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
break #Execution is done
|
break #Execution is done
|
||||||
else:
|
else:
|
||||||
|
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
|
||||||
|
# bytesIO = BytesIO(out[8:])
|
||||||
|
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
|
||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = get_history(prompt_id)[prompt_id]
|
history = get_history(prompt_id)[prompt_id]
|
||||||
@@ -151,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
|||||||
128
server.py
128
server.py
@@ -12,6 +12,8 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
|
import socket
|
||||||
|
import ipaddress
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -80,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
|
|||||||
|
|
||||||
return cors_middleware
|
return cors_middleware
|
||||||
|
|
||||||
|
def is_loopback(host):
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if ipaddress.ip_address(host).is_loopback:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
else:
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
def create_origin_only_middleware():
|
||||||
|
@web.middleware
|
||||||
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
|
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||||
|
#in that case the Host and Origin hostnames won't match
|
||||||
|
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||||
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
|
host = request.headers['Host']
|
||||||
|
origin = request.headers['Origin']
|
||||||
|
host_domain = host.lower()
|
||||||
|
parsed = urllib.parse.urlparse(origin)
|
||||||
|
origin_domain = parsed.netloc.lower()
|
||||||
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
|
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||||
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
|
host_domain = host_domain_parsed.hostname
|
||||||
|
if host_domain_parsed.port is None:
|
||||||
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
|
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
|
if host_domain != origin_domain:
|
||||||
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response = web.Response()
|
||||||
|
else:
|
||||||
|
response = await handler(request)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return origin_only_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
@@ -99,6 +163,8 @@ class PromptServer():
|
|||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||||
|
else:
|
||||||
|
middlewares.append(create_origin_only_middleware())
|
||||||
|
|
||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
@@ -155,6 +221,12 @@ class PromptServer():
|
|||||||
def get_embeddings(self):
|
def get_embeddings(self):
|
||||||
embeddings = folder_paths.get_filename_list("embeddings")
|
embeddings = folder_paths.get_filename_list("embeddings")
|
||||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||||
|
|
||||||
|
@routes.get("/models")
|
||||||
|
def list_model_types(request):
|
||||||
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||||
|
|
||||||
|
return web.json_response(model_types)
|
||||||
|
|
||||||
@routes.get("/models/{folder}")
|
@routes.get("/models/{folder}")
|
||||||
async def get_models(request):
|
async def get_models(request):
|
||||||
@@ -418,12 +490,17 @@ class PromptServer():
|
|||||||
async def system_stats(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
cpu_device = comfy.model_management.torch.device("cpu")
|
||||||
|
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||||
|
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"ram_total": ram_total,
|
||||||
|
"ram_free": ram_free,
|
||||||
"comfyui_version": get_comfyui_version(),
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"pytorch_version": comfy.model_management.torch_version,
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
@@ -480,14 +557,15 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
with folder_paths.cache_helper:
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
out = {}
|
||||||
try:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
out[x] = node_info(x)
|
try:
|
||||||
except Exception as e:
|
out[x] = node_info(x)
|
||||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
except Exception as e:
|
||||||
logging.error(traceback.format_exc())
|
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||||
return web.json_response(out)
|
logging.error(traceback.format_exc())
|
||||||
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/object_info/{node_class}")
|
@routes.get("/object_info/{node_class}")
|
||||||
async def get_object_info_node(request):
|
async def get_object_info_node(request):
|
||||||
@@ -601,6 +679,7 @@ class PromptServer():
|
|||||||
|
|
||||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||||
|
# NOTE: This was an experiment and WILL BE REMOVED
|
||||||
@routes.post("/internal/models/download")
|
@routes.post("/internal/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
@@ -611,10 +690,11 @@ class PromptServer():
|
|||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
|
folder_path = data.get('folder_path')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename or not folder_path:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
session = self.client_session
|
session = self.client_session
|
||||||
@@ -622,7 +702,7 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
return web.json_response(task.result().to_dict())
|
||||||
@@ -739,6 +819,9 @@ class PromptServer():
|
|||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||||
|
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||||
|
|
||||||
|
async def start_multi_address(self, addresses, call_on_start=None):
|
||||||
runner = web.AppRunner(self.app, access_log=None)
|
runner = web.AppRunner(self.app, access_log=None)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
ssl_ctx = None
|
ssl_ctx = None
|
||||||
@@ -749,17 +832,26 @@ class PromptServer():
|
|||||||
keyfile=args.tls_keyfile)
|
keyfile=args.tls_keyfile)
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
|
||||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
logging.info("Starting server\n")
|
||||||
await site.start()
|
for addr in addresses:
|
||||||
|
address = addr[0]
|
||||||
|
port = addr[1]
|
||||||
|
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
self.address = address
|
if not hasattr(self, 'address'):
|
||||||
self.port = port
|
self.address = address #TODO: remove this
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
if ':' in address:
|
||||||
|
address_print = "[{}]".format(address)
|
||||||
|
else:
|
||||||
|
address_print = address
|
||||||
|
|
||||||
|
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||||
|
|
||||||
if verbose:
|
|
||||||
logging.info("Starting server\n")
|
|
||||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
|
||||||
if call_on_start is not None:
|
if call_on_start is not None:
|
||||||
call_on_start(scheme, address, port)
|
call_on_start(scheme, self.address, self.port)
|
||||||
|
|
||||||
def add_on_prompt_handler(self, handler):
|
def add_on_prompt_handler(self, handler):
|
||||||
self.on_prompt_handlers.append(handler)
|
self.on_prompt_handlers.append(handler)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
## Install test dependencies
|
## Install test dependencies
|
||||||
|
|
||||||
`pip install -r tests-units/requirements.txt`
|
`pip install -r tests-unit/requirements.txt`
|
||||||
|
|
||||||
## Run tests
|
## Run tests
|
||||||
`pytest tests-units/`
|
`pytest tests-unit/`
|
||||||
|
|||||||
66
tests-unit/comfy_test/folder_path_test.py
Normal file
66
tests-unit/comfy_test/folder_path_test.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
### 🗻 This file is created through the spirit of Mount Fuji at its peak
|
||||||
|
# TODO(yoland): clean up this after I get back down
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_directory_by_type():
|
||||||
|
test_dir = "/test/dir"
|
||||||
|
folder_paths.set_output_directory(test_dir)
|
||||||
|
assert folder_paths.get_directory_by_type("output") == test_dir
|
||||||
|
assert folder_paths.get_directory_by_type("invalid") is None
|
||||||
|
|
||||||
|
def test_annotated_filepath():
|
||||||
|
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
|
||||||
|
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
|
||||||
|
|
||||||
|
def test_get_annotated_filepath():
|
||||||
|
default_dir = "/default/dir"
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||||
|
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||||
|
|
||||||
|
def test_add_model_folder_path():
|
||||||
|
folder_paths.add_model_folder_path("test_folder", "/test/path")
|
||||||
|
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||||
|
|
||||||
|
def test_recursive_search(temp_dir):
|
||||||
|
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||||
|
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||||
|
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
|
||||||
|
|
||||||
|
files, dirs = folder_paths.recursive_search(temp_dir)
|
||||||
|
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||||
|
assert len(dirs) == 2 # temp_dir and subdir
|
||||||
|
|
||||||
|
def test_filter_files_extensions():
|
||||||
|
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
|
||||||
|
assert folder_paths.filter_files_extensions(files, []) == files
|
||||||
|
|
||||||
|
@patch("folder_paths.recursive_search")
|
||||||
|
@patch("folder_paths.folder_names_and_paths")
|
||||||
|
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||||
|
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
|
||||||
|
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||||
|
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||||
|
|
||||||
|
def test_get_save_image_path(temp_dir):
|
||||||
|
with patch("folder_paths.output_directory", temp_dir):
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||||
|
assert os.path.samefile(full_output_folder, temp_dir)
|
||||||
|
assert filename == "test"
|
||||||
|
assert counter == 1
|
||||||
|
assert subfolder == ""
|
||||||
|
assert filename_prefix == "test"
|
||||||
0
tests-unit/folder_paths_test/__init__.py
Normal file
0
tests-unit/folder_paths_test/__init__.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from folder_paths import filter_files_content_types
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def file_extensions():
|
||||||
|
return {
|
||||||
|
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||||
|
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||||
|
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mock_dir(file_extensions):
|
||||||
|
with tempfile.TemporaryDirectory() as directory:
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
for extension in extensions:
|
||||||
|
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
|
||||||
|
f.write(f"Sample {content_type} file in {extension} format")
|
||||||
|
yield directory
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_correctly(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
for extension in extensions:
|
||||||
|
assert f"sample_{content_type}.{extension}" in filtered_files
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_uniquely(mock_dir, file_extensions):
|
||||||
|
files = os.listdir(mock_dir)
|
||||||
|
for content_type, extensions in file_extensions.items():
|
||||||
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
assert len(filtered_files) == len(extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_bad_extensions():
|
||||||
|
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_extension():
|
||||||
|
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_no_files():
|
||||||
|
files = []
|
||||||
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
@@ -1,10 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientResponse
|
from aiohttp import ClientResponse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
yield tmpdirname
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +49,7 @@ class ContentMock:
|
|||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
@@ -53,15 +60,13 @@ async def test_download_model_success():
|
|||||||
mock_make_request = AsyncMock(return_value=mock_response)
|
mock_make_request = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock file operations
|
|
||||||
mock_open = MagicMock()
|
|
||||||
mock_file = MagicMock()
|
|
||||||
mock_open.return_value.__enter__.return_value = mock_file
|
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
|
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
@@ -69,6 +74,7 @@ async def test_download_model_success():
|
|||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'checkpoints',
|
'checkpoints',
|
||||||
|
temp_dir,
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,44 +89,48 @@ async def test_download_model_success():
|
|||||||
|
|
||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check final call
|
# Check final call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify file writing
|
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
||||||
mock_file.write.assert_any_call(b'a' * 500)
|
assert os.path.exists(mock_file_path)
|
||||||
mock_file.write.assert_any_call(b'b' * 300)
|
with open(mock_file_path, 'rb') as mock_file:
|
||||||
mock_file.write.assert_any_call(b'c' * 200)
|
assert mock_file.read() == b''.join(chunks)
|
||||||
|
os.remove(mock_file_path)
|
||||||
|
|
||||||
# Verify request was made
|
# Verify request was made
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_url_request_failure():
|
async def test_download_model_url_request_failure(temp_dir):
|
||||||
# Mock dependencies
|
# Mock dependencies
|
||||||
mock_response = AsyncMock(spec=ClientResponse)
|
mock_response = AsyncMock(spec=ClientResponse)
|
||||||
mock_response.status = 404 # Simulate a "Not Found" error
|
mock_response.status = 404 # Simulate a "Not Found" error
|
||||||
mock_get = AsyncMock(return_value=mock_response)
|
mock_get = AsyncMock(return_value=mock_response)
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||||
|
|
||||||
# Mock the create_model_path function
|
# Mock the create_model_path function
|
||||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
patch('folder_paths.folder_names_and_paths', fake_paths):
|
||||||
# Call the function
|
# Call the function
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_get,
|
mock_get,
|
||||||
'model.safetensors',
|
'model.safetensors',
|
||||||
'http://example.com/model.safetensors',
|
'http://example.com/model.safetensors',
|
||||||
'mock_directory',
|
'checkpoints',
|
||||||
mock_progress_callback
|
temp_dir,
|
||||||
)
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the expected behavior
|
# Assert the expected behavior
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
# Check that progress_callback was called with the correct arguments
|
# Check that progress_callback was called with the correct arguments
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.PENDING,
|
status=DownloadStatusType.PENDING,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
mock_progress_callback.assert_called_with(
|
mock_progress_callback.assert_called_with(
|
||||||
'mock_directory/model.safetensors',
|
'model.safetensors',
|
||||||
DownloadModelStatus(
|
DownloadModelStatus(
|
||||||
status=DownloadStatusType.ERROR,
|
status=DownloadStatusType.ERROR,
|
||||||
progress_percentage=0,
|
progress_percentage=0,
|
||||||
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
mock_make_request = AsyncMock()
|
mock_make_request = AsyncMock()
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
'model.sft',
|
'model.sft',
|
||||||
'http://example.com/model.sft',
|
'http://example.com/model.sft',
|
||||||
'../bad_path',
|
'../bad_path',
|
||||||
|
'../bad_path',
|
||||||
mock_progress_callback
|
mock_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert the result
|
# Assert the result
|
||||||
assert isinstance(result, DownloadModelStatus)
|
assert isinstance(result, DownloadModelStatus)
|
||||||
assert result.message == 'Invalid model subdirectory'
|
assert result.message.startswith('Invalid or unrecognized model directory')
|
||||||
assert result.status == 'error'
|
assert result.status == 'error'
|
||||||
assert result.already_existed is False
|
assert result.already_existed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_model_invalid_folder_path():
|
||||||
|
mock_make_request = AsyncMock()
|
||||||
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
result = await download_model(
|
||||||
|
mock_make_request,
|
||||||
|
'model.sft',
|
||||||
|
'http://example.com/model.sft',
|
||||||
|
'checkpoints',
|
||||||
|
'invalid_path',
|
||||||
|
mock_progress_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, DownloadModelStatus)
|
||||||
|
assert result.message.startswith("Invalid folder path")
|
||||||
|
assert result.status == 'error'
|
||||||
|
assert result.already_existed is False
|
||||||
|
|
||||||
# For create_model_path function
|
|
||||||
def test_create_model_path(tmp_path, monkeypatch):
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
mock_models_dir = tmp_path / "models"
|
model_name = "model.safetensors"
|
||||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
folder_path = os.path.join(tmp_path, "mock_dir")
|
||||||
|
|
||||||
model_name = "test_model.sft"
|
file_path = create_model_path(model_name, folder_path)
|
||||||
model_directory = "test_dir"
|
|
||||||
|
assert file_path == os.path.join(folder_path, "model.safetensors")
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
|
||||||
|
|
||||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
|
||||||
assert relative_path == f"{model_directory}/{model_name}"
|
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("../path_traversal.safetensors", folder_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Invalid model directory"):
|
||||||
|
create_model_path("/etc/some_root_path", folder_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
file_path = tmp_path / "existing_model.sft"
|
file_path = tmp_path / "existing_model.sft"
|
||||||
file_path.touch() # Create an empty file
|
file_path.touch() # Create an empty file
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert result.message == "existing_model.sft already exists"
|
assert result.message == "existing_model.sft already exists"
|
||||||
assert result.already_existed is True
|
assert result.already_existed is True
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.sft",
|
"existing_model.sft",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||||
file_path = tmp_path / "non_existing_model.sft"
|
file_path = tmp_path / "non_existing_model.sft"
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_callback.assert_not_called()
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_no_content_length():
|
async def test_track_download_progress_no_content_length(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {} # No Content-Length header
|
mock_response.headers = {} # No Content-Length header
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
chunks = [b'a' * 500, b'b' * 500]
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
|
||||||
|
|
||||||
with patch('builtins.open', mock_open):
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
result = await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
result = await track_download_progress(
|
||||||
mock_callback, 'models/model.sft', interval=0.1
|
mock_response, full_path, 'model.sft',
|
||||||
)
|
mock_callback, interval=0.1
|
||||||
|
)
|
||||||
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Check that progress was reported even without knowing the total size
|
# Check that progress was reported even without knowing the total size
|
||||||
mock_callback.assert_any_call(
|
mock_callback.assert_any_call(
|
||||||
'models/model.sft',
|
'model.sft',
|
||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_interval():
|
async def test_track_download_progress_interval(temp_dir):
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
mock_response.headers = {'Content-Length': '1000'}
|
mock_response.headers = {'Content-Length': '1000'}
|
||||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
chunks = [b'a' * 100] * 10
|
||||||
|
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
mock_open = MagicMock(return_value=MagicMock())
|
mock_open = MagicMock(return_value=MagicMock())
|
||||||
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
|
|||||||
mock_time = MagicMock()
|
mock_time = MagicMock()
|
||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
with patch('builtins.open', mock_open), \
|
full_path = os.path.join(temp_dir, 'model.sft')
|
||||||
patch('time.time', mock_time):
|
|
||||||
await track_download_progress(
|
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
|
||||||
mock_callback, 'models/model.sft', interval=1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print out the actual call count and the arguments of each call for debugging
|
with patch('time.time', mock_time):
|
||||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
await track_download_progress(
|
||||||
for i, call in enumerate(mock_callback.call_args_list):
|
mock_response, full_path, 'model.sft',
|
||||||
args, kwargs = call
|
mock_callback, interval=1.0
|
||||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
)
|
||||||
|
|
||||||
|
assert os.path.exists(full_path)
|
||||||
|
with open(full_path, 'rb') as f:
|
||||||
|
assert f.read() == b''.join(chunks)
|
||||||
|
os.remove(full_path)
|
||||||
|
|
||||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||||
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
|
|||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
def test_valid_subdirectory():
|
|
||||||
assert validate_model_subdirectory("valid-model123") is True
|
|
||||||
|
|
||||||
def test_subdirectory_too_long():
|
|
||||||
assert validate_model_subdirectory("a" * 51) is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_double_dots():
|
|
||||||
assert validate_model_subdirectory("model/../unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_slash():
|
|
||||||
assert validate_model_subdirectory("model/unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_special_characters():
|
|
||||||
assert validate_model_subdirectory("model@unsafe") is False
|
|
||||||
|
|
||||||
def test_subdirectory_with_underscore_and_dash():
|
|
||||||
assert validate_model_subdirectory("valid_model-name") is True
|
|
||||||
|
|
||||||
def test_empty_subdirectory():
|
|
||||||
assert validate_model_subdirectory("") is False
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
("valid_model.safetensors", True),
|
("valid_model.safetensors", True),
|
||||||
("valid_model.sft", True),
|
("valid_model.sft", True),
|
||||||
|
|||||||
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from aiohttp import web
|
||||||
|
from app.user_manager import UserManager
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
pytestmark = (
|
||||||
|
pytest.mark.asyncio
|
||||||
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_manager(tmp_path):
|
||||||
|
um = UserManager()
|
||||||
|
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||||
|
tmp_path, file
|
||||||
|
)
|
||||||
|
return um
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(user_manager):
|
||||||
|
app = web.Application()
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
user_manager.add_routes(routes)
|
||||||
|
app.add_routes(routes)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == ["file1.txt"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir")
|
||||||
|
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&full_info=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["path"] == "file1.txt"
|
||||||
|
assert "size" in result[0]
|
||||||
|
assert "modified" in result[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert await resp.json() == [
|
||||||
|
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=")
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||||
|
os_sep = "\\"
|
||||||
|
with patch("os.sep", os_sep):
|
||||||
|
with patch("os.path.sep", os_sep):
|
||||||
|
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||||
|
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||||
|
f.write("test content")
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0] # Ensure backslash is not present
|
||||||
|
assert result[0] == "subdir/file1.txt"
|
||||||
|
|
||||||
|
# Test with full_info
|
||||||
|
resp = await client.get(
|
||||||
|
"/userdata?dir=test_dir&recurse=true&full_info=true"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
result = await resp.json()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||||
|
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||||
|
assert result[0]["path"] == "subdir/file1.txt"
|
||||||
126
tests-unit/utils/extra_config_test.py
Normal file
126
tests-unit/utils/extra_config_test.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch, mock_open
|
||||||
|
|
||||||
|
from utils.extra_config import load_extra_path_config
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content():
|
||||||
|
return {
|
||||||
|
'test_config': {
|
||||||
|
'base_path': '~/App/',
|
||||||
|
'checkpoints': 'subfolder1',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanded_home():
|
||||||
|
return '/home/user'
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def yaml_config_with_appdata():
|
||||||
|
return """
|
||||||
|
test_config:
|
||||||
|
base_path: '%APPDATA%/ComfyUI'
|
||||||
|
checkpoints: 'models/checkpoints'
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||||
|
return yaml.safe_load(yaml_config_with_appdata)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expandvars_appdata():
|
||||||
|
mock = Mock()
|
||||||
|
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||||
|
return mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_add_model_folder_path():
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_expanduser(mock_expanded_home):
|
||||||
|
def _expanduser(path):
|
||||||
|
if path.startswith('~/'):
|
||||||
|
return os.path.join(mock_expanded_home, path[2:])
|
||||||
|
return path
|
||||||
|
return _expanduser
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_yaml_safe_load(mock_yaml_content):
|
||||||
|
return Mock(return_value=mock_yaml_content)
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||||
|
def test_load_extra_model_paths_expands_userpath(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expanduser,
|
||||||
|
mock_yaml_safe_load,
|
||||||
|
mock_expanded_home
|
||||||
|
):
|
||||||
|
# Attach mocks used by load_extra_path_config
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check if add_model_folder_path was called with the correct arguments
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args[0] == expected_call[0]
|
||||||
|
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
|
||||||
|
assert actual_call.args[2] == expected_call[2]
|
||||||
|
|
||||||
|
# Check if yaml.safe_load was called
|
||||||
|
mock_yaml_safe_load.assert_called_once()
|
||||||
|
|
||||||
|
# Check if open was called with the correct file path
|
||||||
|
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||||
|
|
||||||
|
@patch('builtins.open', new_callable=mock_open)
|
||||||
|
def test_load_extra_model_paths_expands_appdata(
|
||||||
|
mock_file,
|
||||||
|
monkeypatch,
|
||||||
|
mock_add_model_folder_path,
|
||||||
|
mock_expandvars_appdata,
|
||||||
|
yaml_config_with_appdata,
|
||||||
|
mock_yaml_content_appdata
|
||||||
|
):
|
||||||
|
# Set the mock_file to return yaml with appdata as a variable
|
||||||
|
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||||
|
|
||||||
|
# Attach mocks
|
||||||
|
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||||
|
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
|
||||||
|
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
|
||||||
|
|
||||||
|
# Mock expanduser to do nothing (since we're not testing it here)
|
||||||
|
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
|
||||||
|
|
||||||
|
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||||
|
load_extra_path_config(dummy_yaml_file_name)
|
||||||
|
|
||||||
|
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||||
|
expected_calls = [
|
||||||
|
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
|
# Check the base path variable was expanded
|
||||||
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
|
assert actual_call.args == expected_call
|
||||||
|
|
||||||
|
# Verify that expandvars was called
|
||||||
|
assert mock_expandvars_appdata.called
|
||||||
@@ -496,3 +496,29 @@ class TestExecution:
|
|||||||
assert len(images) == 1, "Should have 1 image"
|
assert len(images) == 1, "Should have 1 image"
|
||||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
|
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||||
|
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||||
|
# only that one entry in the list is blocked.
|
||||||
|
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
|
||||||
|
int1 = g.node("StubInt", value=1)
|
||||||
|
int2 = g.node("StubInt", value=2)
|
||||||
|
int3 = g.node("StubInt", value=3)
|
||||||
|
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||||
|
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||||
|
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||||
|
|
||||||
|
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||||
|
output = g.node("PreviewImage", images=list_output.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
assert result.did_run(output), "The execution should have run"
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 2, "Should have 2 images"
|
||||||
|
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||||
|
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||||
|
|||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/extra_config.py
Normal file
28
utils/extra_config.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
|
||||||
|
def load_extra_path_config(yaml_path):
|
||||||
|
with open(yaml_path, 'r') as stream:
|
||||||
|
config = yaml.safe_load(stream)
|
||||||
|
for c in config:
|
||||||
|
conf = config[c]
|
||||||
|
if conf is None:
|
||||||
|
continue
|
||||||
|
base_path = None
|
||||||
|
if "base_path" in conf:
|
||||||
|
base_path = conf.pop("base_path")
|
||||||
|
base_path = os.path.expandvars(os.path.expanduser(base_path))
|
||||||
|
is_default = False
|
||||||
|
if "is_default" in conf:
|
||||||
|
is_default = conf.pop("is_default")
|
||||||
|
for x in conf:
|
||||||
|
for y in conf[x].split("\n"):
|
||||||
|
if len(y) == 0:
|
||||||
|
continue
|
||||||
|
full_path = y
|
||||||
|
if base_path is not None:
|
||||||
|
full_path = os.path.join(base_path, full_path)
|
||||||
|
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||||
|
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||||
1
web/assets/CREDIT.txt
generated
vendored
Normal file
1
web/assets/CREDIT.txt
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.
|
||||||
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
@@ -0,0 +1,792 @@
|
|||||||
|
|
||||||
|
.editable-text[data-v-54da6fc9] {
|
||||||
|
display: inline;
|
||||||
|
}
|
||||||
|
.editable-text input[data-v-54da6fc9] {
|
||||||
|
width: 100%;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
|
||||||
|
z-index: 9999;
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-fc3f26e3] .editable-text {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
[data-v-fc3f26e3] .editable-text input {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
/* Override the default font size */
|
||||||
|
font-size: inherit;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
}
|
||||||
|
.side-bar-button-selected .side-bar-button-icon {
|
||||||
|
font-size: var(--sidebar-icon-size) !important;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-button[data-v-caa3ee9c] {
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: var(--sidebar-width);
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-left: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||||
|
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||||
|
border-right: 4px solid var(--p-button-text-primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
:root {
|
||||||
|
--sidebar-width: 64px;
|
||||||
|
--sidebar-icon-size: 1.5rem;
|
||||||
|
}
|
||||||
|
:root .small-sidebar {
|
||||||
|
--sidebar-width: 40px;
|
||||||
|
--sidebar-icon-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-tool-bar-container[data-v-4da64512] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
|
||||||
|
pointer-events: auto;
|
||||||
|
|
||||||
|
width: var(--sidebar-width);
|
||||||
|
height: 100%;
|
||||||
|
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
}
|
||||||
|
.side-tool-bar-end[data-v-4da64512] {
|
||||||
|
align-self: flex-end;
|
||||||
|
margin-top: auto;
|
||||||
|
}
|
||||||
|
.sidebar-content-container[data-v-4da64512] {
|
||||||
|
height: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-splitter-gutter {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.gutter-hidden {
|
||||||
|
display: none !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-bar-panel[data-v-b9df3042] {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.splitter-overlay[data-v-b9df3042] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
background-color: transparent;
|
||||||
|
pointer-events: none;
|
||||||
|
/* Set it the same as the ComfyUI menu */
|
||||||
|
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
|
||||||
|
999 should be sufficient to make sure splitter overlays on node's DOM
|
||||||
|
widgets */
|
||||||
|
z-index: 999;
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
._content[data-v-e7b35fd9] {
|
||||||
|
|
||||||
|
display: flex;
|
||||||
|
|
||||||
|
flex-direction: column
|
||||||
|
}
|
||||||
|
._content[data-v-e7b35fd9] > :not([hidden]) ~ :not([hidden]) {
|
||||||
|
|
||||||
|
--tw-space-y-reverse: 0;
|
||||||
|
|
||||||
|
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
|
||||||
|
|
||||||
|
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse))
|
||||||
|
}
|
||||||
|
._footer[data-v-e7b35fd9] {
|
||||||
|
|
||||||
|
display: flex;
|
||||||
|
|
||||||
|
flex-direction: column;
|
||||||
|
|
||||||
|
align-items: flex-end;
|
||||||
|
|
||||||
|
padding-top: 1rem
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-37f672ab] .highlight {
|
||||||
|
background-color: var(--p-primary-color);
|
||||||
|
color: var(--p-primary-contrast-color);
|
||||||
|
font-weight: bold;
|
||||||
|
border-radius: 0.25rem;
|
||||||
|
padding: 0rem 0.125rem;
|
||||||
|
margin: -0.125rem 0.125rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.slot_row[data-v-ff07c900] {
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Original N-Sidebar styles */
|
||||||
|
._sb_dot[data-v-ff07c900] {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: grey;
|
||||||
|
}
|
||||||
|
.node_header[data-v-ff07c900] {
|
||||||
|
line-height: 1;
|
||||||
|
padding: 8px 13px 7px;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
font-size: 15px;
|
||||||
|
text-wrap: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.headdot[data-v-ff07c900] {
|
||||||
|
width: 10px;
|
||||||
|
height: 10px;
|
||||||
|
float: inline-start;
|
||||||
|
margin-right: 8px;
|
||||||
|
}
|
||||||
|
.IMAGE[data-v-ff07c900] {
|
||||||
|
background-color: #64b5f6;
|
||||||
|
}
|
||||||
|
.VAE[data-v-ff07c900] {
|
||||||
|
background-color: #ff6e6e;
|
||||||
|
}
|
||||||
|
.LATENT[data-v-ff07c900] {
|
||||||
|
background-color: #ff9cf9;
|
||||||
|
}
|
||||||
|
.MASK[data-v-ff07c900] {
|
||||||
|
background-color: #81c784;
|
||||||
|
}
|
||||||
|
.CONDITIONING[data-v-ff07c900] {
|
||||||
|
background-color: #ffa931;
|
||||||
|
}
|
||||||
|
.CLIP[data-v-ff07c900] {
|
||||||
|
background-color: #ffd500;
|
||||||
|
}
|
||||||
|
.MODEL[data-v-ff07c900] {
|
||||||
|
background-color: #b39ddb;
|
||||||
|
}
|
||||||
|
.CONTROL_NET[data-v-ff07c900] {
|
||||||
|
background-color: #a5d6a7;
|
||||||
|
}
|
||||||
|
._sb_node_preview[data-v-ff07c900] {
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
font-family: 'Open Sans', sans-serif;
|
||||||
|
font-size: small;
|
||||||
|
color: var(--descrip-text);
|
||||||
|
border: 1px solid var(--descrip-text);
|
||||||
|
min-width: 300px;
|
||||||
|
width: -moz-min-content;
|
||||||
|
width: min-content;
|
||||||
|
height: -moz-fit-content;
|
||||||
|
height: fit-content;
|
||||||
|
z-index: 9999;
|
||||||
|
border-radius: 12px;
|
||||||
|
overflow: hidden;
|
||||||
|
font-size: 12px;
|
||||||
|
padding-bottom: 10px;
|
||||||
|
}
|
||||||
|
._sb_node_preview ._sb_description[data-v-ff07c900] {
|
||||||
|
margin: 10px;
|
||||||
|
padding: 6px;
|
||||||
|
background: var(--border-color);
|
||||||
|
border-radius: 5px;
|
||||||
|
font-style: italic;
|
||||||
|
font-weight: 500;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
word-break: break-word;
|
||||||
|
}
|
||||||
|
._sb_table[data-v-ff07c900] {
|
||||||
|
display: grid;
|
||||||
|
|
||||||
|
grid-column-gap: 10px;
|
||||||
|
/* Spazio tra le colonne */
|
||||||
|
width: 100%;
|
||||||
|
/* Imposta la larghezza della tabella al 100% del contenitore */
|
||||||
|
}
|
||||||
|
._sb_row[data-v-ff07c900] {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 10px 1fr 1fr 1fr 10px;
|
||||||
|
grid-column-gap: 10px;
|
||||||
|
align-items: center;
|
||||||
|
padding-left: 9px;
|
||||||
|
padding-right: 9px;
|
||||||
|
}
|
||||||
|
._sb_row_string[data-v-ff07c900] {
|
||||||
|
grid-template-columns: 10px 1fr 1fr 10fr 1fr;
|
||||||
|
}
|
||||||
|
._sb_col[data-v-ff07c900] {
|
||||||
|
border: 0px solid #000;
|
||||||
|
display: flex;
|
||||||
|
align-items: flex-end;
|
||||||
|
flex-direction: row-reverse;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
align-content: flex-start;
|
||||||
|
justify-content: flex-end;
|
||||||
|
}
|
||||||
|
._sb_inherit[data-v-ff07c900] {
|
||||||
|
display: inherit;
|
||||||
|
}
|
||||||
|
._long_field[data-v-ff07c900] {
|
||||||
|
background: var(--bg-color);
|
||||||
|
border: 2px solid var(--border-color);
|
||||||
|
margin: 5px 5px 0 5px;
|
||||||
|
border-radius: 10px;
|
||||||
|
line-height: 1.7;
|
||||||
|
text-wrap: nowrap;
|
||||||
|
}
|
||||||
|
._sb_arrow[data-v-ff07c900] {
|
||||||
|
color: var(--fg-color);
|
||||||
|
}
|
||||||
|
._sb_preview_badge[data-v-ff07c900] {
|
||||||
|
text-align: center;
|
||||||
|
background: var(--comfy-input-bg);
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--error-text);
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] {
|
||||||
|
display: flex;
|
||||||
|
width: 100%;
|
||||||
|
min-width: 26rem;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-container[data-v-2d409367] * {
|
||||||
|
pointer-events: auto;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-preview-container[data-v-2d409367] {
|
||||||
|
position: absolute;
|
||||||
|
left: -350px;
|
||||||
|
top: 50px;
|
||||||
|
}
|
||||||
|
.comfy-vue-node-search-box[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
._filter-button[data-v-2d409367] {
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
._dialog[data-v-2d409367] {
|
||||||
|
min-width: 26rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.invisible-dialog-root {
|
||||||
|
width: 60%;
|
||||||
|
min-width: 24rem;
|
||||||
|
max-width: 48rem;
|
||||||
|
border: 0 !important;
|
||||||
|
background-color: transparent !important;
|
||||||
|
margin-top: 25vh;
|
||||||
|
margin-left: 400px;
|
||||||
|
}
|
||||||
|
@media all and (max-width: 768px) {
|
||||||
|
.invisible-dialog-root {
|
||||||
|
margin-left: 0px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.node-search-box-dialog-mask {
|
||||||
|
align-items: flex-start !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-tooltip[data-v-0a4402f9] {
|
||||||
|
background: var(--comfy-input-bg);
|
||||||
|
border-radius: 5px;
|
||||||
|
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||||
|
color: var(--input-text);
|
||||||
|
font-family: sans-serif;
|
||||||
|
left: 0;
|
||||||
|
max-width: 30vw;
|
||||||
|
padding: 4px 8px;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
transform: translate(5px, calc(-100% - 5px));
|
||||||
|
white-space: pre-wrap;
|
||||||
|
z-index: 99999;
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-buttongroup-vertical[data-v-ce8bd6ac] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
border-radius: var(--p-button-border-radius);
|
||||||
|
overflow: hidden;
|
||||||
|
border: 1px solid var(--p-panel-border-color);
|
||||||
|
}
|
||||||
|
.p-buttongroup-vertical .p-button[data-v-ce8bd6ac] {
|
||||||
|
margin: 0;
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-image-wrap[data-v-9bc23daf] {
|
||||||
|
display: contents;
|
||||||
|
}
|
||||||
|
.comfy-image-blur[data-v-9bc23daf] {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
.comfy-image-main[data-v-9bc23daf] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
-o-object-position: center;
|
||||||
|
object-position: center;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.contain .comfy-image-wrap[data-v-9bc23daf] {
|
||||||
|
position: relative;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
.contain .comfy-image-main[data-v-9bc23daf] {
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
-webkit-backdrop-filter: blur(10px);
|
||||||
|
backdrop-filter: blur(10px);
|
||||||
|
position: absolute;
|
||||||
|
}
|
||||||
|
.broken-image-placeholder[data-v-9bc23daf] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
margin: 2rem;
|
||||||
|
}
|
||||||
|
.broken-image-placeholder i[data-v-9bc23daf] {
|
||||||
|
font-size: 3rem;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-container[data-v-d9c060ae] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.image-preview-mask[data-v-d9c060ae] {
|
||||||
|
position: absolute;
|
||||||
|
left: 50%;
|
||||||
|
top: 50%;
|
||||||
|
transform: translate(-50%, -50%);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
opacity: 0;
|
||||||
|
transition: opacity 0.3s ease;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.result-container:hover .image-preview-mask[data-v-d9c060ae] {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.task-result-preview[data-v-d4c8a1fe] {
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
.task-result-preview i[data-v-d4c8a1fe],
|
||||||
|
.task-result-preview span[data-v-d4c8a1fe] {
|
||||||
|
font-size: 2rem;
|
||||||
|
}
|
||||||
|
.task-item[data-v-d4c8a1fe] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
.task-item-details[data-v-d4c8a1fe] {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 0;
|
||||||
|
padding: 0.6rem;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.task-node-link[data-v-d4c8a1fe] {
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* In dark mode, transparent background color for tags is not ideal for tags that
|
||||||
|
are floating on top of images. */
|
||||||
|
.tag-wrapper[data-v-d4c8a1fe] {
|
||||||
|
background-color: var(--p-primary-contrast-color);
|
||||||
|
border-radius: 6px;
|
||||||
|
display: inline-flex;
|
||||||
|
}
|
||||||
|
.node-name-tag[data-v-d4c8a1fe] {
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
.status-tag-group[data-v-d4c8a1fe] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
.progress-preview-img[data-v-d4c8a1fe] {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
-o-object-fit: cover;
|
||||||
|
object-fit: cover;
|
||||||
|
-o-object-position: center;
|
||||||
|
object-position: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
|
||||||
|
cannot use scoped style here. */
|
||||||
|
img.galleria-image {
|
||||||
|
max-width: 100vw;
|
||||||
|
max-height: 100vh;
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
.p-galleria-close-button {
|
||||||
|
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-vue-side-bar-container[data-v-1b0a8fe3] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
height: 100%;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-header[data-v-1b0a8fe3] {
|
||||||
|
flex-shrink: 0;
|
||||||
|
border-left: none;
|
||||||
|
border-right: none;
|
||||||
|
border-top: none;
|
||||||
|
border-radius: 0;
|
||||||
|
padding: 0.25rem 1rem;
|
||||||
|
min-height: 2.5rem;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-header-span[data-v-1b0a8fe3] {
|
||||||
|
font-size: small;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3] {
|
||||||
|
flex-grow: 1;
|
||||||
|
overflow: auto;
|
||||||
|
scrollbar-width: thin;
|
||||||
|
scrollbar-color: transparent transparent;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar {
|
||||||
|
width: 1px;
|
||||||
|
}
|
||||||
|
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar-thumb {
|
||||||
|
background-color: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.scroll-container[data-v-08fa89b1] {
|
||||||
|
height: 100%;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
.queue-grid[data-v-08fa89b1] {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
|
||||||
|
padding: 0.5rem;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tree-node[data-v-633e27ab] {
|
||||||
|
width: 100%;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
}
|
||||||
|
.leaf-count-badge[data-v-633e27ab] {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
.node-content[data-v-633e27ab] {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
.leaf-label[data-v-633e27ab] {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
[data-v-633e27ab] .editable-text span {
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-bd7bae90] .tree-explorer-node-label {
|
||||||
|
width: 100%;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin-left: var(--p-tree-node-gap);
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
|
||||||
|
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
|
||||||
|
* we can create a visual indicator for the drop target without affecting the layout of other elements.
|
||||||
|
*/
|
||||||
|
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder) {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder.can-drop)::after {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
border: 1px solid var(--p-content-color);
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-lib-node-container[data-v-90dfee08] {
|
||||||
|
height: 100%;
|
||||||
|
width: 100%
|
||||||
|
}
|
||||||
|
|
||||||
|
.p-selectbutton .p-button[data-v-91077f2a] {
|
||||||
|
padding: 0.5rem;
|
||||||
|
}
|
||||||
|
.p-selectbutton .p-button .pi[data-v-91077f2a] {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
}
|
||||||
|
.field[data-v-91077f2a] {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
.color-picker-container[data-v-91077f2a] {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-lib-filter-popup {
|
||||||
|
margin-left: -13px;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-f6a7371a] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
[data-v-f6a7371a] .node-lib-bookmark-tree-explorer {
|
||||||
|
padding-bottom: 2px;
|
||||||
|
}
|
||||||
|
[data-v-f6a7371a] .p-divider {
|
||||||
|
margin: var(--comfy-tree-explorer-item-padding) 0px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model_preview[data-v-32e6c4d9] {
|
||||||
|
background-color: var(--comfy-menu-bg);
|
||||||
|
font-family: 'Open Sans', sans-serif;
|
||||||
|
color: var(--descrip-text);
|
||||||
|
border: 1px solid var(--descrip-text);
|
||||||
|
min-width: 300px;
|
||||||
|
max-width: 500px;
|
||||||
|
width: -moz-fit-content;
|
||||||
|
width: fit-content;
|
||||||
|
height: -moz-fit-content;
|
||||||
|
height: fit-content;
|
||||||
|
z-index: 9999;
|
||||||
|
border-radius: 12px;
|
||||||
|
overflow: hidden;
|
||||||
|
font-size: 12px;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
.model_preview_image[data-v-32e6c4d9] {
|
||||||
|
margin: auto;
|
||||||
|
width: -moz-fit-content;
|
||||||
|
width: fit-content;
|
||||||
|
}
|
||||||
|
.model_preview_image img[data-v-32e6c4d9] {
|
||||||
|
max-width: 100%;
|
||||||
|
max-height: 150px;
|
||||||
|
-o-object-fit: contain;
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
.model_preview_title[data-v-32e6c4d9] {
|
||||||
|
font-weight: bold;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
.model_preview_top_container[data-v-32e6c4d9] {
|
||||||
|
text-align: center;
|
||||||
|
line-height: 0.5;
|
||||||
|
}
|
||||||
|
.model_preview_filename[data-v-32e6c4d9],
|
||||||
|
.model_preview_author[data-v-32e6c4d9],
|
||||||
|
.model_preview_architecture[data-v-32e6c4d9] {
|
||||||
|
display: inline-block;
|
||||||
|
text-align: center;
|
||||||
|
margin: 5px;
|
||||||
|
font-size: 10px;
|
||||||
|
}
|
||||||
|
.model_preview_prefix[data-v-32e6c4d9] {
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.model-lib-model-icon-container[data-v-70b69131] {
|
||||||
|
display: inline-block;
|
||||||
|
position: relative;
|
||||||
|
left: 0;
|
||||||
|
height: 1.5rem;
|
||||||
|
vertical-align: top;
|
||||||
|
width: 0px;
|
||||||
|
}
|
||||||
|
.model-lib-model-icon[data-v-70b69131] {
|
||||||
|
background-size: cover;
|
||||||
|
background-position: center;
|
||||||
|
display: inline-block;
|
||||||
|
position: relative;
|
||||||
|
left: -2.5rem;
|
||||||
|
height: 2rem;
|
||||||
|
width: 2rem;
|
||||||
|
vertical-align: top;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pi-fake-spacer {
|
||||||
|
height: 1px;
|
||||||
|
width: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-74b01bce] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-d2d58252] .comfy-vue-side-bar-body {
|
||||||
|
background: var(--p-tree-background);
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-84e785b8] .p-togglebutton::before {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton {
|
||||||
|
position: relative;
|
||||||
|
flex-shrink: 0;
|
||||||
|
border-radius: 0px;
|
||||||
|
background-color: transparent;
|
||||||
|
padding-left: 0.5rem;
|
||||||
|
padding-right: 0.5rem
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||||
|
border-bottom-width: 2px;
|
||||||
|
border-bottom-color: var(--p-button-text-primary-color)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||||
|
visibility: visible
|
||||||
|
}
|
||||||
|
.status-indicator[data-v-84e785b8] {
|
||||||
|
position: absolute;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
top: 50%;
|
||||||
|
left: 50%;
|
||||||
|
transform: translate(-50%, -50%)
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||||
|
display: none
|
||||||
|
}
|
||||||
|
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||||
|
visibility: hidden
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-menubar[data-v-2ec1b620] .p-menubar-item-link svg {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
[data-v-2ec1b620] .p-menubar-submenu.dropdown-direction-up {
|
||||||
|
top: auto;
|
||||||
|
bottom: 100%;
|
||||||
|
flex-direction: column-reverse;
|
||||||
|
}
|
||||||
|
.keybinding-tag[data-v-2ec1b620] {
|
||||||
|
background: var(--p-content-hover-background);
|
||||||
|
border-color: var(--p-content-border-color);
|
||||||
|
border-style: solid;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-v-713442be] .p-inputtext {
|
||||||
|
border-top-left-radius: 0;
|
||||||
|
border-bottom-left-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-queue-button[data-v-fcd3efcd] .p-splitbutton-dropdown {
|
||||||
|
border-top-right-radius: 0;
|
||||||
|
border-bottom-right-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.actionbar[data-v-bc6c78dd] {
|
||||||
|
pointer-events: all;
|
||||||
|
position: fixed;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
.actionbar.is-docked[data-v-bc6c78dd] {
|
||||||
|
position: static;
|
||||||
|
border-style: none;
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0px;
|
||||||
|
}
|
||||||
|
.actionbar.is-dragging[data-v-bc6c78dd] {
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
[data-v-bc6c78dd] .p-panel-content {
|
||||||
|
padding: 0.25rem;
|
||||||
|
}
|
||||||
|
[data-v-bc6c78dd] .p-panel-header {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfyui-menu[data-v-b13fdc92] {
|
||||||
|
width: 100vw;
|
||||||
|
background: var(--comfy-menu-bg);
|
||||||
|
color: var(--fg-color);
|
||||||
|
font-family: Arial, Helvetica, sans-serif;
|
||||||
|
font-size: 0.8em;
|
||||||
|
box-sizing: border-box;
|
||||||
|
z-index: 1000;
|
||||||
|
order: 0;
|
||||||
|
grid-column: 1/-1;
|
||||||
|
max-height: 90vh;
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone[data-v-b13fdc92] {
|
||||||
|
background: var(--p-highlight-background);
|
||||||
|
}
|
||||||
|
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
|
||||||
|
background: var(--p-highlight-background-focus);
|
||||||
|
}
|
||||||
|
.comfyui-logo[data-v-b13fdc92] {
|
||||||
|
font-size: 1.2em;
|
||||||
|
-webkit-user-select: none;
|
||||||
|
-moz-user-select: none;
|
||||||
|
user-select: none;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
@@ -0,0 +1,865 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { k as app, aP as LGraphCanvas, bO as useToastStore, ca as $el, z as LiteGraph } from "./index-DGAbdBYF.js";
|
||||||
|
const colorPalettes = {
|
||||||
|
dark: {
|
||||||
|
id: "dark",
|
||||||
|
name: "Dark (Default)",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#FFD500",
|
||||||
|
// bright yellow
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
// light blue-gray
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
// rusty brown-orange
|
||||||
|
CONDITIONING: "#FFA931",
|
||||||
|
// vibrant orange-yellow
|
||||||
|
CONTROL_NET: "#6EE7B7",
|
||||||
|
// soft mint green
|
||||||
|
IMAGE: "#64B5F6",
|
||||||
|
// bright sky blue
|
||||||
|
LATENT: "#FF9CF9",
|
||||||
|
// light pink-purple
|
||||||
|
MASK: "#81C784",
|
||||||
|
// muted green
|
||||||
|
MODEL: "#B39DDB",
|
||||||
|
// light lavender-purple
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
// light green-yellow
|
||||||
|
VAE: "#FF6E6E",
|
||||||
|
// bright red
|
||||||
|
NOISE: "#B0B0B0",
|
||||||
|
// gray
|
||||||
|
GUIDER: "#66FFFF",
|
||||||
|
// cyan
|
||||||
|
SAMPLER: "#ECB4B4",
|
||||||
|
// very soft red
|
||||||
|
SIGMAS: "#CDFFCD",
|
||||||
|
// soft lime green
|
||||||
|
TAESD: "#DCC274"
|
||||||
|
// cheesecake
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#222",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#AAA",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#333",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#353535",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#666",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#222",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#666",
|
||||||
|
WIDGET_TEXT_COLOR: "#DDD",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA",
|
||||||
|
BADGE_FG_COLOR: "#FFF",
|
||||||
|
BADGE_BG_COLOR: "#0F1F0F"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fff",
|
||||||
|
"bg-color": "#202020",
|
||||||
|
"comfy-menu-bg": "#353535",
|
||||||
|
"comfy-input-bg": "#222",
|
||||||
|
"input-text": "#ddd",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#4e4e4e",
|
||||||
|
"tr-even-bg-color": "#222",
|
||||||
|
"tr-odd-bg-color": "#353535",
|
||||||
|
"content-bg": "#4e4e4e",
|
||||||
|
"content-fg": "#fff",
|
||||||
|
"content-hover-bg": "#222",
|
||||||
|
"content-hover-fg": "#fff"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
light: {
|
||||||
|
id: "light",
|
||||||
|
name: "Light",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#FFA726",
|
||||||
|
// orange
|
||||||
|
CLIP_VISION: "#5C6BC0",
|
||||||
|
// indigo
|
||||||
|
CLIP_VISION_OUTPUT: "#8D6E63",
|
||||||
|
// brown
|
||||||
|
CONDITIONING: "#EF5350",
|
||||||
|
// red
|
||||||
|
CONTROL_NET: "#66BB6A",
|
||||||
|
// green
|
||||||
|
IMAGE: "#42A5F5",
|
||||||
|
// blue
|
||||||
|
LATENT: "#AB47BC",
|
||||||
|
// purple
|
||||||
|
MASK: "#9CCC65",
|
||||||
|
// light green
|
||||||
|
MODEL: "#7E57C2",
|
||||||
|
// deep purple
|
||||||
|
STYLE_MODEL: "#D4E157",
|
||||||
|
// lime
|
||||||
|
VAE: "#FF7043"
|
||||||
|
// deep orange
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "lightgray",
|
||||||
|
NODE_TITLE_COLOR: "#222",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#000",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#444",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#F7F7F7",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#F5F5F5",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#CCC",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#000",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.1)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#D4D4D4",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#999",
|
||||||
|
WIDGET_TEXT_COLOR: "#222",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#555",
|
||||||
|
LINK_COLOR: "#4CAF50",
|
||||||
|
EVENT_LINK_COLOR: "#FF9800",
|
||||||
|
CONNECTING_LINK_COLOR: "#2196F3",
|
||||||
|
BADGE_FG_COLOR: "#000",
|
||||||
|
BADGE_BG_COLOR: "#FFF"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#222",
|
||||||
|
"bg-color": "#DDD",
|
||||||
|
"comfy-menu-bg": "#F5F5F5",
|
||||||
|
"comfy-input-bg": "#C9C9C9",
|
||||||
|
"input-text": "#222",
|
||||||
|
"descrip-text": "#444",
|
||||||
|
"drag-text": "#555",
|
||||||
|
"error-text": "#F44336",
|
||||||
|
"border-color": "#888",
|
||||||
|
"tr-even-bg-color": "#f9f9f9",
|
||||||
|
"tr-odd-bg-color": "#fff",
|
||||||
|
"content-bg": "#e0e0e0",
|
||||||
|
"content-fg": "#222",
|
||||||
|
"content-hover-bg": "#adadad",
|
||||||
|
"content-hover-fg": "#222"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
solarized: {
|
||||||
|
id: "solarized",
|
||||||
|
name: "Solarized",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
CLIP: "#2AB7CA",
|
||||||
|
// light blue
|
||||||
|
CLIP_VISION: "#6c71c4",
|
||||||
|
// blue violet
|
||||||
|
CLIP_VISION_OUTPUT: "#859900",
|
||||||
|
// olive green
|
||||||
|
CONDITIONING: "#d33682",
|
||||||
|
// magenta
|
||||||
|
CONTROL_NET: "#d1ffd7",
|
||||||
|
// light mint green
|
||||||
|
IMAGE: "#5940bb",
|
||||||
|
// deep blue violet
|
||||||
|
LATENT: "#268bd2",
|
||||||
|
// blue
|
||||||
|
MASK: "#CCC9E7",
|
||||||
|
// light purple-gray
|
||||||
|
MODEL: "#dc322f",
|
||||||
|
// red
|
||||||
|
STYLE_MODEL: "#1a998a",
|
||||||
|
// teal
|
||||||
|
UPSCALE_MODEL: "#054A29",
|
||||||
|
// dark green
|
||||||
|
VAE: "#facfad"
|
||||||
|
// light pink-orange
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
NODE_TITLE_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#A9D400",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#657b83",
|
||||||
|
// Base00
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#094656",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#073642",
|
||||||
|
// Base02
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#839496",
|
||||||
|
// Base0
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#002b36",
|
||||||
|
// Base03
|
||||||
|
WIDGET_OUTLINE_COLOR: "#839496",
|
||||||
|
// Base0
|
||||||
|
WIDGET_TEXT_COLOR: "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#93a1a1",
|
||||||
|
// Base1
|
||||||
|
LINK_COLOR: "#2aa198",
|
||||||
|
// Solarized Cyan
|
||||||
|
EVENT_LINK_COLOR: "#268bd2",
|
||||||
|
// Solarized Blue
|
||||||
|
CONNECTING_LINK_COLOR: "#859900"
|
||||||
|
// Solarized Green
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fdf6e3",
|
||||||
|
// Base3
|
||||||
|
"bg-color": "#002b36",
|
||||||
|
// Base03
|
||||||
|
"comfy-menu-bg": "#073642",
|
||||||
|
// Base02
|
||||||
|
"comfy-input-bg": "#002b36",
|
||||||
|
// Base03
|
||||||
|
"input-text": "#93a1a1",
|
||||||
|
// Base1
|
||||||
|
"descrip-text": "#586e75",
|
||||||
|
// Base01
|
||||||
|
"drag-text": "#839496",
|
||||||
|
// Base0
|
||||||
|
"error-text": "#dc322f",
|
||||||
|
// Solarized Red
|
||||||
|
"border-color": "#657b83",
|
||||||
|
// Base00
|
||||||
|
"tr-even-bg-color": "#002b36",
|
||||||
|
"tr-odd-bg-color": "#073642",
|
||||||
|
"content-bg": "#657b83",
|
||||||
|
"content-fg": "#fdf6e3",
|
||||||
|
"content-hover-bg": "#002b36",
|
||||||
|
"content-hover-fg": "#fdf6e3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
arc: {
|
||||||
|
id: "arc",
|
||||||
|
name: "Arc",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#2b2f38",
|
||||||
|
NODE_TITLE_COLOR: "#b2b7bd",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#AAA",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#2b2f38",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#242730",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#6e7581",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 22,
|
||||||
|
WIDGET_BGCOLOR: "#2b2f38",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#6e7581",
|
||||||
|
WIDGET_TEXT_COLOR: "#DDD",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#b2b7bd",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#fff",
|
||||||
|
"bg-color": "#2b2f38",
|
||||||
|
"comfy-menu-bg": "#242730",
|
||||||
|
"comfy-input-bg": "#2b2f38",
|
||||||
|
"input-text": "#ddd",
|
||||||
|
"descrip-text": "#b2b7bd",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#6e7581",
|
||||||
|
"tr-even-bg-color": "#2b2f38",
|
||||||
|
"tr-odd-bg-color": "#242730",
|
||||||
|
"content-bg": "#6e7581",
|
||||||
|
"content-fg": "#fff",
|
||||||
|
"content-hover-bg": "#2b2f38",
|
||||||
|
"content-hover-fg": "#fff"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nord: {
|
||||||
|
id: "nord",
|
||||||
|
name: "Nord",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#212732",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#bcc2c8",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#2e3440",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#161b22",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#545d70",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#2e3440",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#545d70",
|
||||||
|
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#e5eaf0",
|
||||||
|
"bg-color": "#2e3440",
|
||||||
|
"comfy-menu-bg": "#161b22",
|
||||||
|
"comfy-input-bg": "#2e3440",
|
||||||
|
"input-text": "#bcc2c8",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#545d70",
|
||||||
|
"tr-even-bg-color": "#2e3440",
|
||||||
|
"tr-odd-bg-color": "#161b22",
|
||||||
|
"content-bg": "#545d70",
|
||||||
|
"content-fg": "#e5eaf0",
|
||||||
|
"content-hover-bg": "#2e3440",
|
||||||
|
"content-hover-fg": "#e5eaf0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
github: {
|
||||||
|
id: "github",
|
||||||
|
name: "Github",
|
||||||
|
colors: {
|
||||||
|
node_slot: {
|
||||||
|
BOOLEAN: "",
|
||||||
|
CLIP: "#eacb8b",
|
||||||
|
CLIP_VISION: "#A8DADC",
|
||||||
|
CLIP_VISION_OUTPUT: "#ad7452",
|
||||||
|
CONDITIONING: "#cf876f",
|
||||||
|
CONTROL_NET: "#00d78d",
|
||||||
|
CONTROL_NET_WEIGHTS: "",
|
||||||
|
FLOAT: "",
|
||||||
|
GLIGEN: "",
|
||||||
|
IMAGE: "#80a1c0",
|
||||||
|
IMAGEUPLOAD: "",
|
||||||
|
INT: "",
|
||||||
|
LATENT: "#b38ead",
|
||||||
|
LATENT_KEYFRAME: "",
|
||||||
|
MASK: "#a3bd8d",
|
||||||
|
MODEL: "#8978a7",
|
||||||
|
SAMPLER: "",
|
||||||
|
SIGMAS: "",
|
||||||
|
STRING: "",
|
||||||
|
STYLE_MODEL: "#C2FFAE",
|
||||||
|
T2I_ADAPTER_WEIGHTS: "",
|
||||||
|
TAESD: "#DCC274",
|
||||||
|
TIMESTEP_KEYFRAME: "",
|
||||||
|
UPSCALE_MODEL: "",
|
||||||
|
VAE: "#be616b"
|
||||||
|
},
|
||||||
|
litegraph_base: {
|
||||||
|
BACKGROUND_IMAGE: "",
|
||||||
|
CLEAR_BACKGROUND_COLOR: "#040506",
|
||||||
|
NODE_TITLE_COLOR: "#999",
|
||||||
|
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||||
|
NODE_TEXT_SIZE: 14,
|
||||||
|
NODE_TEXT_COLOR: "#bcc2c8",
|
||||||
|
NODE_SUBTEXT_SIZE: 12,
|
||||||
|
NODE_DEFAULT_COLOR: "#161b22",
|
||||||
|
NODE_DEFAULT_BGCOLOR: "#13171d",
|
||||||
|
NODE_DEFAULT_BOXCOLOR: "#30363d",
|
||||||
|
NODE_DEFAULT_SHAPE: "box",
|
||||||
|
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||||
|
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||||
|
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||||
|
DEFAULT_GROUP_FONT: 24,
|
||||||
|
WIDGET_BGCOLOR: "#161b22",
|
||||||
|
WIDGET_OUTLINE_COLOR: "#30363d",
|
||||||
|
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||||
|
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||||
|
LINK_COLOR: "#9A9",
|
||||||
|
EVENT_LINK_COLOR: "#A86",
|
||||||
|
CONNECTING_LINK_COLOR: "#AFA"
|
||||||
|
},
|
||||||
|
comfy_base: {
|
||||||
|
"fg-color": "#e5eaf0",
|
||||||
|
"bg-color": "#161b22",
|
||||||
|
"comfy-menu-bg": "#13171d",
|
||||||
|
"comfy-input-bg": "#161b22",
|
||||||
|
"input-text": "#bcc2c8",
|
||||||
|
"descrip-text": "#999",
|
||||||
|
"drag-text": "#ccc",
|
||||||
|
"error-text": "#ff4444",
|
||||||
|
"border-color": "#30363d",
|
||||||
|
"tr-even-bg-color": "#161b22",
|
||||||
|
"tr-odd-bg-color": "#13171d",
|
||||||
|
"content-bg": "#30363d",
|
||||||
|
"content-fg": "#e5eaf0",
|
||||||
|
"content-hover-bg": "#161b22",
|
||||||
|
"content-hover-fg": "#e5eaf0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const id = "Comfy.ColorPalette";
|
||||||
|
const idCustomColorPalettes = "Comfy.CustomColorPalettes";
|
||||||
|
const defaultColorPaletteId = "dark";
|
||||||
|
const els = {
|
||||||
|
select: null
|
||||||
|
};
|
||||||
|
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
|
||||||
|
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
|
||||||
|
}, "getCustomColorPalettes");
|
||||||
|
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
|
||||||
|
return app.ui.settings.setSettingValue(
|
||||||
|
idCustomColorPalettes,
|
||||||
|
customColorPalettes
|
||||||
|
);
|
||||||
|
}, "setCustomColorPalettes");
|
||||||
|
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
|
||||||
|
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||||
|
if (!colorPaletteId) {
|
||||||
|
colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
||||||
|
}
|
||||||
|
if (colorPaletteId.startsWith("custom_")) {
|
||||||
|
colorPaletteId = colorPaletteId.substr(7);
|
||||||
|
let customColorPalettes = getCustomColorPalettes();
|
||||||
|
if (customColorPalettes[colorPaletteId]) {
|
||||||
|
return customColorPalettes[colorPaletteId];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return colorPalettes[colorPaletteId];
|
||||||
|
}, "getColorPalette");
|
||||||
|
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||||
|
app.ui.settings.setSettingValue(id, colorPaletteId);
|
||||||
|
}, "setColorPalette");
|
||||||
|
app.registerExtension({
|
||||||
|
name: id,
|
||||||
|
init() {
|
||||||
|
LGraphCanvas.prototype.updateBackground = function(image, clearBackgroundColor) {
|
||||||
|
this._bg_img = new Image();
|
||||||
|
this._bg_img.name = image;
|
||||||
|
this._bg_img.src = image;
|
||||||
|
this._bg_img.onload = () => {
|
||||||
|
this.draw(true, true);
|
||||||
|
};
|
||||||
|
this.background_image = image;
|
||||||
|
this.clear_background = true;
|
||||||
|
this.clear_background_color = clearBackgroundColor;
|
||||||
|
this._pattern = null;
|
||||||
|
};
|
||||||
|
},
|
||||||
|
addCustomNodeDefs(node_defs) {
|
||||||
|
const sortObjectKeys = /* @__PURE__ */ __name((unordered) => {
|
||||||
|
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||||
|
obj[key] = unordered[key];
|
||||||
|
return obj;
|
||||||
|
}, {});
|
||||||
|
}, "sortObjectKeys");
|
||||||
|
function getSlotTypes() {
|
||||||
|
var types = [];
|
||||||
|
const defs = node_defs;
|
||||||
|
for (const nodeId in defs) {
|
||||||
|
const nodeData = defs[nodeId];
|
||||||
|
var inputs = nodeData["input"]["required"];
|
||||||
|
if (nodeData["input"]["optional"] !== void 0) {
|
||||||
|
inputs = Object.assign(
|
||||||
|
{},
|
||||||
|
nodeData["input"]["required"],
|
||||||
|
nodeData["input"]["optional"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (const inputName in inputs) {
|
||||||
|
const inputData = inputs[inputName];
|
||||||
|
const type = inputData[0];
|
||||||
|
if (!Array.isArray(type)) {
|
||||||
|
types.push(type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const o in nodeData["output"]) {
|
||||||
|
const output = nodeData["output"][o];
|
||||||
|
types.push(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return types;
|
||||||
|
}
|
||||||
|
__name(getSlotTypes, "getSlotTypes");
|
||||||
|
function completeColorPalette(colorPalette) {
|
||||||
|
var types = getSlotTypes();
|
||||||
|
for (const type of types) {
|
||||||
|
if (!colorPalette.colors.node_slot[type]) {
|
||||||
|
colorPalette.colors.node_slot[type] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
colorPalette.colors.node_slot = sortObjectKeys(
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
return colorPalette;
|
||||||
|
}
|
||||||
|
__name(completeColorPalette, "completeColorPalette");
|
||||||
|
const getColorPaletteTemplate = /* @__PURE__ */ __name(async () => {
|
||||||
|
let colorPalette = {
|
||||||
|
id: "my_color_palette_unique_id",
|
||||||
|
name: "My Color Palette",
|
||||||
|
colors: {
|
||||||
|
node_slot: {},
|
||||||
|
litegraph_base: {},
|
||||||
|
comfy_base: {}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
|
||||||
|
for (const key in defaultColorPalette2.colors.litegraph_base) {
|
||||||
|
if (!colorPalette.colors.litegraph_base[key]) {
|
||||||
|
colorPalette.colors.litegraph_base[key] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const key in defaultColorPalette2.colors.comfy_base) {
|
||||||
|
if (!colorPalette.colors.comfy_base[key]) {
|
||||||
|
colorPalette.colors.comfy_base[key] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return completeColorPalette(colorPalette);
|
||||||
|
}, "getColorPaletteTemplate");
|
||||||
|
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||||
|
if (typeof colorPalette !== "object") {
|
||||||
|
useToastStore().addAlert("Invalid color palette.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.id) {
|
||||||
|
useToastStore().addAlert("Color palette missing id.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.name) {
|
||||||
|
useToastStore().addAlert("Color palette missing name.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!colorPalette.colors) {
|
||||||
|
useToastStore().addAlert("Color palette missing colors.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.node_slot && typeof colorPalette.colors.node_slot !== "object") {
|
||||||
|
useToastStore().addAlert("Invalid color palette colors.node_slot.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
|
customColorPalettes[colorPalette.id] = colorPalette;
|
||||||
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
for (const option of els.select.childNodes) {
|
||||||
|
if (option.value === "custom_" + colorPalette.id) {
|
||||||
|
els.select.removeChild(option);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
els.select.append(
|
||||||
|
$el("option", {
|
||||||
|
textContent: colorPalette.name + " (custom)",
|
||||||
|
value: "custom_" + colorPalette.id,
|
||||||
|
selected: true
|
||||||
|
})
|
||||||
|
);
|
||||||
|
setColorPalette("custom_" + colorPalette.id);
|
||||||
|
await loadColorPalette(colorPalette);
|
||||||
|
}, "addCustomColorPalette");
|
||||||
|
const deleteCustomColorPalette = /* @__PURE__ */ __name(async (colorPaletteId) => {
|
||||||
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
|
delete customColorPalettes[colorPaletteId];
|
||||||
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
for (const opt of els.select.childNodes) {
|
||||||
|
const option = opt;
|
||||||
|
if (option.value === defaultColorPaletteId) {
|
||||||
|
option.selected = true;
|
||||||
|
}
|
||||||
|
if (option.value === "custom_" + colorPaletteId) {
|
||||||
|
els.select.removeChild(option);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setColorPalette(defaultColorPaletteId);
|
||||||
|
await loadColorPalette(getColorPalette());
|
||||||
|
}, "deleteCustomColorPalette");
|
||||||
|
const loadColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||||
|
colorPalette = await completeColorPalette(colorPalette);
|
||||||
|
if (colorPalette.colors) {
|
||||||
|
if (colorPalette.colors.node_slot) {
|
||||||
|
Object.assign(
|
||||||
|
app.canvas.default_connection_color_byType,
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
Object.assign(
|
||||||
|
LGraphCanvas.link_type_colors,
|
||||||
|
colorPalette.colors.node_slot
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.litegraph_base) {
|
||||||
|
app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR;
|
||||||
|
app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR;
|
||||||
|
for (const key in colorPalette.colors.litegraph_base) {
|
||||||
|
if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) {
|
||||||
|
LiteGraph[key] = colorPalette.colors.litegraph_base[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.comfy_base) {
|
||||||
|
const rootStyle = document.documentElement.style;
|
||||||
|
for (const key in colorPalette.colors.comfy_base) {
|
||||||
|
rootStyle.setProperty(
|
||||||
|
"--" + key,
|
||||||
|
colorPalette.colors.comfy_base[key]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR) {
|
||||||
|
app.bypassBgColor = colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR;
|
||||||
|
}
|
||||||
|
app.canvas.draw(true, true);
|
||||||
|
}
|
||||||
|
}, "loadColorPalette");
|
||||||
|
const fileInput = $el("input", {
|
||||||
|
type: "file",
|
||||||
|
accept: ".json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body,
|
||||||
|
onchange: /* @__PURE__ */ __name(() => {
|
||||||
|
const file = fileInput.files[0];
|
||||||
|
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = async () => {
|
||||||
|
await addCustomColorPalette(JSON.parse(reader.result));
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
}
|
||||||
|
}, "onchange")
|
||||||
|
});
|
||||||
|
app.ui.settings.addSetting({
|
||||||
|
id,
|
||||||
|
category: ["Comfy", "ColorPalette"],
|
||||||
|
name: "Color Palette",
|
||||||
|
type: /* @__PURE__ */ __name((name, setter, value) => {
|
||||||
|
const options = [
|
||||||
|
...Object.values(colorPalettes).map(
|
||||||
|
(c) => $el("option", {
|
||||||
|
textContent: c.name,
|
||||||
|
value: c.id,
|
||||||
|
selected: c.id === value
|
||||||
|
})
|
||||||
|
),
|
||||||
|
...Object.values(getCustomColorPalettes()).map(
|
||||||
|
(c) => $el("option", {
|
||||||
|
textContent: `${c.name} (custom)`,
|
||||||
|
value: `custom_${c.id}`,
|
||||||
|
selected: `custom_${c.id}` === value
|
||||||
|
})
|
||||||
|
)
|
||||||
|
];
|
||||||
|
els.select = $el(
|
||||||
|
"select",
|
||||||
|
{
|
||||||
|
style: {
|
||||||
|
marginBottom: "0.15rem",
|
||||||
|
width: "100%"
|
||||||
|
},
|
||||||
|
onchange: /* @__PURE__ */ __name((e) => {
|
||||||
|
setter(e.target.value);
|
||||||
|
}, "onchange")
|
||||||
|
},
|
||||||
|
options
|
||||||
|
);
|
||||||
|
return $el("tr", [
|
||||||
|
$el("td", [
|
||||||
|
els.select,
|
||||||
|
$el(
|
||||||
|
"div",
|
||||||
|
{
|
||||||
|
style: {
|
||||||
|
display: "grid",
|
||||||
|
gap: "4px",
|
||||||
|
gridAutoFlow: "column"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Export",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
const colorPaletteId = app.ui.settings.getSettingValue(
|
||||||
|
id,
|
||||||
|
defaultColorPaletteId
|
||||||
|
);
|
||||||
|
const colorPalette = await completeColorPalette(
|
||||||
|
getColorPalette(colorPaletteId)
|
||||||
|
);
|
||||||
|
const json = JSON.stringify(colorPalette, null, 2);
|
||||||
|
const blob = new Blob([json], { type: "application/json" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: colorPaletteId + ".json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function() {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Import",
|
||||||
|
onclick: /* @__PURE__ */ __name(() => {
|
||||||
|
fileInput.click();
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Template",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
const colorPalette = await getColorPaletteTemplate();
|
||||||
|
const json = JSON.stringify(colorPalette, null, 2);
|
||||||
|
const blob = new Blob([json], { type: "application/json" });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: "color_palette.json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function() {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}, "onclick")
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Delete",
|
||||||
|
onclick: /* @__PURE__ */ __name(async () => {
|
||||||
|
let colorPaletteId = app.ui.settings.getSettingValue(
|
||||||
|
id,
|
||||||
|
defaultColorPaletteId
|
||||||
|
);
|
||||||
|
if (colorPalettes[colorPaletteId]) {
|
||||||
|
useToastStore().addAlert(
|
||||||
|
"You cannot delete a built-in color palette."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (colorPaletteId.startsWith("custom_")) {
|
||||||
|
colorPaletteId = colorPaletteId.substr(7);
|
||||||
|
}
|
||||||
|
await deleteCustomColorPalette(colorPaletteId);
|
||||||
|
}, "onclick")
|
||||||
|
})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
]);
|
||||||
|
}, "type"),
|
||||||
|
defaultValue: defaultColorPaletteId,
|
||||||
|
async onChange(value) {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let palette = colorPalettes[value];
|
||||||
|
if (palette) {
|
||||||
|
await loadColorPalette(palette);
|
||||||
|
} else if (value.startsWith("custom_")) {
|
||||||
|
value = value.substr(7);
|
||||||
|
let customColorPalettes = getCustomColorPalettes();
|
||||||
|
if (customColorPalettes[value]) {
|
||||||
|
palette = customColorPalettes[value];
|
||||||
|
await loadColorPalette(customColorPalettes[value]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let { BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR } = palette.colors.litegraph_base;
|
||||||
|
if (BACKGROUND_IMAGE === void 0 || CLEAR_BACKGROUND_COLOR === void 0) {
|
||||||
|
const base = colorPalettes["dark"].colors.litegraph_base;
|
||||||
|
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
|
||||||
|
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
|
||||||
|
}
|
||||||
|
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
window.comfyAPI = window.comfyAPI || {};
|
||||||
|
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
|
||||||
|
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
|
||||||
|
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
|
||||||
|
export {
|
||||||
|
defaultColorPalette as d,
|
||||||
|
getColorPalette as g
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=colorPalette-D5oi2-2V.js.map
|
||||||
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/index-BD-Ia1C4.js.map
generated
vendored
1
web/assets/index-BD-Ia1C4.js.map
generated
vendored
File diff suppressed because one or more lines are too long
5673
web/assets/index-_5czGnTA.css → web/assets/index-BHJGjcJh.css
generated
vendored
5673
web/assets/index-_5czGnTA.css → web/assets/index-BHJGjcJh.css
generated
vendored
File diff suppressed because it is too large
Load Diff
2018
web/assets/index-BD-Ia1C4.js → web/assets/index-BMC1ey-i.js
generated
vendored
2018
web/assets/index-BD-Ia1C4.js → web/assets/index-BMC1ey-i.js
generated
vendored
File diff suppressed because it is too large
Load Diff
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
1
web/assets/index-CI3N807S.js.map
generated
vendored
1
web/assets/index-CI3N807S.js.map
generated
vendored
File diff suppressed because one or more lines are too long
164716
web/assets/index-CI3N807S.js → web/assets/index-DGAbdBYF.js
generated
vendored
164716
web/assets/index-CI3N807S.js → web/assets/index-DGAbdBYF.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/index-DGAbdBYF.js.map
generated
vendored
Normal file
1
web/assets/index-DGAbdBYF.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
2602
web/assets/sorted-custom-node-map.json
generated
vendored
Normal file
2602
web/assets/sorted-custom-node-map.json
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
34
web/assets/userSelection-BGzn1LuN.css → web/assets/userSelection-CmI-fOSC.css
generated
vendored
34
web/assets/userSelection-BGzn1LuN.css → web/assets/userSelection-CmI-fOSC.css
generated
vendored
@@ -1,3 +1,37 @@
|
|||||||
|
.lds-ring {
|
||||||
|
display: inline-block;
|
||||||
|
position: relative;
|
||||||
|
width: 1em;
|
||||||
|
height: 1em;
|
||||||
|
}
|
||||||
|
.lds-ring div {
|
||||||
|
box-sizing: border-box;
|
||||||
|
display: block;
|
||||||
|
position: absolute;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
border: 0.15em solid #fff;
|
||||||
|
border-radius: 50%;
|
||||||
|
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
|
||||||
|
border-color: #fff transparent transparent transparent;
|
||||||
|
}
|
||||||
|
.lds-ring div:nth-child(1) {
|
||||||
|
animation-delay: -0.45s;
|
||||||
|
}
|
||||||
|
.lds-ring div:nth-child(2) {
|
||||||
|
animation-delay: -0.3s;
|
||||||
|
}
|
||||||
|
.lds-ring div:nth-child(3) {
|
||||||
|
animation-delay: -0.15s;
|
||||||
|
}
|
||||||
|
@keyframes lds-ring {
|
||||||
|
0% {
|
||||||
|
transform: rotate(0deg);
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
.comfy-user-selection {
|
.comfy-user-selection {
|
||||||
width: 100vw;
|
width: 100vw;
|
||||||
height: 100vh;
|
height: 100vh;
|
||||||
1
web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
1
web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
File diff suppressed because one or more lines are too long
13
web/assets/userSelection-CyXKCVy3.js → web/assets/userSelection-Duxc-t_S.js
generated
vendored
13
web/assets/userSelection-CyXKCVy3.js → web/assets/userSelection-Duxc-t_S.js
generated
vendored
@@ -1,6 +1,15 @@
|
|||||||
var __defProp = Object.defineProperty;
|
var __defProp = Object.defineProperty;
|
||||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
|
import { b4 as api, ca as $el } from "./index-DGAbdBYF.js";
|
||||||
|
function createSpinner() {
|
||||||
|
const div = document.createElement("div");
|
||||||
|
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
|
||||||
|
return div.firstElementChild;
|
||||||
|
}
|
||||||
|
__name(createSpinner, "createSpinner");
|
||||||
|
window.comfyAPI = window.comfyAPI || {};
|
||||||
|
window.comfyAPI.spinner = window.comfyAPI.spinner || {};
|
||||||
|
window.comfyAPI.spinner.createSpinner = createSpinner;
|
||||||
class UserSelectionScreen {
|
class UserSelectionScreen {
|
||||||
static {
|
static {
|
||||||
__name(this, "UserSelectionScreen");
|
__name(this, "UserSelectionScreen");
|
||||||
@@ -117,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
|||||||
export {
|
export {
|
||||||
UserSelectionScreen
|
UserSelectionScreen
|
||||||
};
|
};
|
||||||
//# sourceMappingURL=userSelection-CyXKCVy3.js.map
|
//# sourceMappingURL=userSelection-Duxc-t_S.js.map
|
||||||
1
web/assets/userSelection-Duxc-t_S.js.map
generated
vendored
Normal file
1
web/assets/userSelection-Duxc-t_S.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
756
web/assets/widgetInputs-DdoWwzg5.js
generated
vendored
Normal file
756
web/assets/widgetInputs-DdoWwzg5.js
generated
vendored
Normal file
@@ -0,0 +1,756 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { l as LGraphNode, k as app, cf as applyTextReplacements, ce as ComfyWidgets, ci as addValueControlWidgets, z as LiteGraph } from "./index-DGAbdBYF.js";
|
||||||
|
const CONVERTED_TYPE = "converted-widget";
|
||||||
|
const VALID_TYPES = [
|
||||||
|
"STRING",
|
||||||
|
"combo",
|
||||||
|
"number",
|
||||||
|
"toggle",
|
||||||
|
"BOOLEAN",
|
||||||
|
"text",
|
||||||
|
"string"
|
||||||
|
];
|
||||||
|
const CONFIG = Symbol();
|
||||||
|
const GET_CONFIG = Symbol();
|
||||||
|
const TARGET = Symbol();
|
||||||
|
const replacePropertyName = "Run widget replace on values";
|
||||||
|
class PrimitiveNode extends LGraphNode {
|
||||||
|
static {
|
||||||
|
__name(this, "PrimitiveNode");
|
||||||
|
}
|
||||||
|
controlValues;
|
||||||
|
lastType;
|
||||||
|
static category;
|
||||||
|
constructor(title) {
|
||||||
|
super(title);
|
||||||
|
this.addOutput("connect to widget input", "*");
|
||||||
|
this.serialize_widgets = true;
|
||||||
|
this.isVirtualNode = true;
|
||||||
|
if (!this.properties || !(replacePropertyName in this.properties)) {
|
||||||
|
this.addProperty(replacePropertyName, false, "boolean");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
applyToGraph(extraLinks = []) {
|
||||||
|
if (!this.outputs[0].links?.length) return;
|
||||||
|
function get_links(node) {
|
||||||
|
let links2 = [];
|
||||||
|
for (const l of node.outputs[0].links) {
|
||||||
|
const linkInfo = app.graph.links[l];
|
||||||
|
const n = node.graph.getNodeById(linkInfo.target_id);
|
||||||
|
if (n.type == "Reroute") {
|
||||||
|
links2 = links2.concat(get_links(n));
|
||||||
|
} else {
|
||||||
|
links2.push(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return links2;
|
||||||
|
}
|
||||||
|
__name(get_links, "get_links");
|
||||||
|
let links = [
|
||||||
|
...get_links(this).map((l) => app.graph.links[l]),
|
||||||
|
...extraLinks
|
||||||
|
];
|
||||||
|
let v = this.widgets?.[0].value;
|
||||||
|
if (v && this.properties[replacePropertyName]) {
|
||||||
|
v = applyTextReplacements(app, v);
|
||||||
|
}
|
||||||
|
for (const linkInfo of links) {
|
||||||
|
const node = this.graph.getNodeById(linkInfo.target_id);
|
||||||
|
const input = node.inputs[linkInfo.target_slot];
|
||||||
|
let widget;
|
||||||
|
if (input.widget[TARGET]) {
|
||||||
|
widget = input.widget[TARGET];
|
||||||
|
} else {
|
||||||
|
const widgetName = input.widget.name;
|
||||||
|
if (widgetName) {
|
||||||
|
widget = node.widgets.find((w) => w.name === widgetName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (widget) {
|
||||||
|
widget.value = v;
|
||||||
|
if (widget.callback) {
|
||||||
|
widget.callback(
|
||||||
|
widget.value,
|
||||||
|
app.canvas,
|
||||||
|
node,
|
||||||
|
app.canvas.graph_mouse,
|
||||||
|
{}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
refreshComboInNode() {
|
||||||
|
const widget = this.widgets?.[0];
|
||||||
|
if (widget?.type === "combo") {
|
||||||
|
widget.options.values = this.outputs[0].widget[GET_CONFIG]()[0];
|
||||||
|
if (!widget.options.values.includes(widget.value)) {
|
||||||
|
widget.value = widget.options.values[0];
|
||||||
|
widget.callback(widget.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onAfterGraphConfigured() {
|
||||||
|
if (this.outputs[0].links?.length && !this.widgets?.length) {
|
||||||
|
if (!this.#onFirstConnection()) return;
|
||||||
|
if (this.widgets) {
|
||||||
|
for (let i = 0; i < this.widgets_values.length; i++) {
|
||||||
|
const w = this.widgets[i];
|
||||||
|
if (w) {
|
||||||
|
w.value = this.widgets_values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.#mergeWidgetConfig();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onConnectionsChange(_, index, connected) {
|
||||||
|
if (app.configuringGraph) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const links = this.outputs[0].links;
|
||||||
|
if (connected) {
|
||||||
|
if (links?.length && !this.widgets?.length) {
|
||||||
|
this.#onFirstConnection();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.#mergeWidgetConfig();
|
||||||
|
if (!links?.length) {
|
||||||
|
this.onLastDisconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onConnectOutput(slot, type, input, target_node, target_slot) {
|
||||||
|
if (!input.widget) {
|
||||||
|
if (!(input.type in ComfyWidgets)) return false;
|
||||||
|
}
|
||||||
|
if (this.outputs[slot].links?.length) {
|
||||||
|
const valid = this.#isValidConnection(input);
|
||||||
|
if (valid) {
|
||||||
|
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
|
||||||
|
}
|
||||||
|
return valid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#onFirstConnection(recreating) {
|
||||||
|
if (!this.outputs[0].links) {
|
||||||
|
this.onLastDisconnect();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const linkId = this.outputs[0].links[0];
|
||||||
|
const link = this.graph.links[linkId];
|
||||||
|
if (!link) return;
|
||||||
|
const theirNode = this.graph.getNodeById(link.target_id);
|
||||||
|
if (!theirNode || !theirNode.inputs) return;
|
||||||
|
const input = theirNode.inputs[link.target_slot];
|
||||||
|
if (!input) return;
|
||||||
|
let widget;
|
||||||
|
if (!input.widget) {
|
||||||
|
if (!(input.type in ComfyWidgets)) return;
|
||||||
|
widget = { name: input.name, [GET_CONFIG]: () => [input.type, {}] };
|
||||||
|
} else {
|
||||||
|
widget = input.widget;
|
||||||
|
}
|
||||||
|
const config = widget[GET_CONFIG]?.();
|
||||||
|
if (!config) return;
|
||||||
|
const { type } = getWidgetType(config);
|
||||||
|
this.outputs[0].type = type;
|
||||||
|
this.outputs[0].name = type;
|
||||||
|
this.outputs[0].widget = widget;
|
||||||
|
this.#createWidget(
|
||||||
|
widget[CONFIG] ?? config,
|
||||||
|
theirNode,
|
||||||
|
widget.name,
|
||||||
|
recreating,
|
||||||
|
widget[TARGET]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
|
||||||
|
let type = inputData[0];
|
||||||
|
if (type instanceof Array) {
|
||||||
|
type = "COMBO";
|
||||||
|
}
|
||||||
|
const size = this.size;
|
||||||
|
let widget;
|
||||||
|
if (type in ComfyWidgets) {
|
||||||
|
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
|
||||||
|
} else {
|
||||||
|
widget = this.addWidget(type, "value", null, () => {
|
||||||
|
}, {});
|
||||||
|
}
|
||||||
|
if (targetWidget) {
|
||||||
|
widget.value = targetWidget.value;
|
||||||
|
} else if (node?.widgets && widget) {
|
||||||
|
const theirWidget = node.widgets.find((w) => w.name === widgetName);
|
||||||
|
if (theirWidget) {
|
||||||
|
widget.value = theirWidget.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
|
||||||
|
let control_value = this.widgets_values?.[1];
|
||||||
|
if (!control_value) {
|
||||||
|
control_value = "fixed";
|
||||||
|
}
|
||||||
|
addValueControlWidgets(
|
||||||
|
this,
|
||||||
|
widget,
|
||||||
|
control_value,
|
||||||
|
void 0,
|
||||||
|
inputData
|
||||||
|
);
|
||||||
|
let filter = this.widgets_values?.[2];
|
||||||
|
if (filter && this.widgets.length === 3) {
|
||||||
|
this.widgets[2].value = filter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const controlValues = this.controlValues;
|
||||||
|
if (this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
|
||||||
|
for (let i = 0; i < controlValues.length; i++) {
|
||||||
|
this.widgets[i + 1].value = controlValues[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const callback = widget.callback;
|
||||||
|
const self = this;
|
||||||
|
widget.callback = function() {
|
||||||
|
const r = callback ? callback.apply(this, arguments) : void 0;
|
||||||
|
self.applyToGraph();
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
this.size = [
|
||||||
|
Math.max(this.size[0], size[0]),
|
||||||
|
Math.max(this.size[1], size[1])
|
||||||
|
];
|
||||||
|
if (!recreating) {
|
||||||
|
const sz = this.computeSize();
|
||||||
|
if (this.size[0] < sz[0]) {
|
||||||
|
this.size[0] = sz[0];
|
||||||
|
}
|
||||||
|
if (this.size[1] < sz[1]) {
|
||||||
|
this.size[1] = sz[1];
|
||||||
|
}
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
if (this.onResize) {
|
||||||
|
this.onResize(this.size);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recreateWidget() {
|
||||||
|
const values = this.widgets?.map((w) => w.value);
|
||||||
|
this.#removeWidgets();
|
||||||
|
this.#onFirstConnection(true);
|
||||||
|
if (values?.length) {
|
||||||
|
for (let i = 0; i < this.widgets?.length; i++)
|
||||||
|
this.widgets[i].value = values[i];
|
||||||
|
}
|
||||||
|
return this.widgets?.[0];
|
||||||
|
}
|
||||||
|
#mergeWidgetConfig() {
|
||||||
|
const output = this.outputs[0];
|
||||||
|
const links = output.links;
|
||||||
|
const hasConfig = !!output.widget[CONFIG];
|
||||||
|
if (hasConfig) {
|
||||||
|
delete output.widget[CONFIG];
|
||||||
|
}
|
||||||
|
if (links?.length < 2 && hasConfig) {
|
||||||
|
if (links.length) {
|
||||||
|
this.recreateWidget();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const config1 = output.widget[GET_CONFIG]();
|
||||||
|
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||||
|
if (!isNumber) return;
|
||||||
|
for (const linkId of links) {
|
||||||
|
const link = app.graph.links[linkId];
|
||||||
|
if (!link) continue;
|
||||||
|
const theirNode = app.graph.getNodeById(link.target_id);
|
||||||
|
const theirInput = theirNode.inputs[link.target_slot];
|
||||||
|
this.#isValidConnection(theirInput, hasConfig);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isValidWidgetLink(originSlot, targetNode, targetWidget) {
|
||||||
|
const config2 = getConfig.call(targetNode, targetWidget.name) ?? [
|
||||||
|
targetWidget.type,
|
||||||
|
targetWidget.options || {}
|
||||||
|
];
|
||||||
|
if (!isConvertibleWidget(targetWidget, config2)) return false;
|
||||||
|
const output = this.outputs[originSlot];
|
||||||
|
if (!(output.widget?.[CONFIG] ?? output.widget?.[GET_CONFIG]())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return !!mergeIfValid.call(this, output, config2);
|
||||||
|
}
|
||||||
|
#isValidConnection(input, forceUpdate) {
|
||||||
|
const output = this.outputs[0];
|
||||||
|
const config2 = input.widget[GET_CONFIG]();
|
||||||
|
return !!mergeIfValid.call(
|
||||||
|
this,
|
||||||
|
output,
|
||||||
|
config2,
|
||||||
|
forceUpdate,
|
||||||
|
this.recreateWidget
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#removeWidgets() {
|
||||||
|
if (this.widgets) {
|
||||||
|
for (const w of this.widgets) {
|
||||||
|
if (w.onRemove) {
|
||||||
|
w.onRemove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.controlValues = [];
|
||||||
|
this.lastType = this.widgets[0]?.type;
|
||||||
|
for (let i = 1; i < this.widgets.length; i++) {
|
||||||
|
this.controlValues.push(this.widgets[i].value);
|
||||||
|
}
|
||||||
|
setTimeout(() => {
|
||||||
|
delete this.lastType;
|
||||||
|
delete this.controlValues;
|
||||||
|
}, 15);
|
||||||
|
this.widgets.length = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onLastDisconnect() {
|
||||||
|
this.outputs[0].type = "*";
|
||||||
|
this.outputs[0].name = "connect to widget input";
|
||||||
|
delete this.outputs[0].widget;
|
||||||
|
this.#removeWidgets();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function getWidgetConfig(slot) {
|
||||||
|
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
|
||||||
|
}
|
||||||
|
__name(getWidgetConfig, "getWidgetConfig");
|
||||||
|
function getConfig(widgetName) {
|
||||||
|
const { nodeData } = this.constructor;
|
||||||
|
return nodeData?.input?.required?.[widgetName] ?? nodeData?.input?.optional?.[widgetName];
|
||||||
|
}
|
||||||
|
__name(getConfig, "getConfig");
|
||||||
|
function isConvertibleWidget(widget, config) {
|
||||||
|
return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput;
|
||||||
|
}
|
||||||
|
__name(isConvertibleWidget, "isConvertibleWidget");
|
||||||
|
function hideWidget(node, widget, suffix = "") {
|
||||||
|
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
|
||||||
|
widget.origType = widget.type;
|
||||||
|
widget.origComputeSize = widget.computeSize;
|
||||||
|
widget.origSerializeValue = widget.serializeValue;
|
||||||
|
widget.computeSize = () => [0, -4];
|
||||||
|
widget.type = CONVERTED_TYPE + suffix;
|
||||||
|
widget.serializeValue = () => {
|
||||||
|
if (!node.inputs) {
|
||||||
|
return void 0;
|
||||||
|
}
|
||||||
|
let node_input = node.inputs.find((i) => i.widget?.name === widget.name);
|
||||||
|
if (!node_input || !node_input.link) {
|
||||||
|
return void 0;
|
||||||
|
}
|
||||||
|
return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
|
||||||
|
};
|
||||||
|
if (widget.linkedWidgets) {
|
||||||
|
for (const w of widget.linkedWidgets) {
|
||||||
|
hideWidget(node, w, ":" + widget.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__name(hideWidget, "hideWidget");
|
||||||
|
function showWidget(widget) {
|
||||||
|
widget.type = widget.origType;
|
||||||
|
widget.computeSize = widget.origComputeSize;
|
||||||
|
widget.serializeValue = widget.origSerializeValue;
|
||||||
|
delete widget.origType;
|
||||||
|
delete widget.origComputeSize;
|
||||||
|
delete widget.origSerializeValue;
|
||||||
|
if (widget.linkedWidgets) {
|
||||||
|
for (const w of widget.linkedWidgets) {
|
||||||
|
showWidget(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__name(showWidget, "showWidget");
|
||||||
|
function convertToInput(node, widget, config) {
|
||||||
|
hideWidget(node, widget);
|
||||||
|
const { type } = getWidgetType(config);
|
||||||
|
const sz = node.size;
|
||||||
|
const inputIsOptional = !!widget.options?.inputIsOptional;
|
||||||
|
const input = node.addInput(widget.name, type, {
|
||||||
|
widget: { name: widget.name, [GET_CONFIG]: () => config },
|
||||||
|
...inputIsOptional ? { shape: LiteGraph.SlotShape.HollowCircle } : {}
|
||||||
|
});
|
||||||
|
for (const widget2 of node.widgets) {
|
||||||
|
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
|
||||||
|
}
|
||||||
|
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||||
|
return input;
|
||||||
|
}
|
||||||
|
__name(convertToInput, "convertToInput");
|
||||||
|
function convertToWidget(node, widget) {
|
||||||
|
showWidget(widget);
|
||||||
|
const sz = node.size;
|
||||||
|
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
|
||||||
|
for (const widget2 of node.widgets) {
|
||||||
|
widget2.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
|
||||||
|
}
|
||||||
|
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||||
|
}
|
||||||
|
__name(convertToWidget, "convertToWidget");
|
||||||
|
function getWidgetType(config) {
|
||||||
|
let type = config[0];
|
||||||
|
if (type instanceof Array) {
|
||||||
|
type = "COMBO";
|
||||||
|
}
|
||||||
|
return { type };
|
||||||
|
}
|
||||||
|
__name(getWidgetType, "getWidgetType");
|
||||||
|
function isValidCombo(combo, obj) {
|
||||||
|
if (!(obj instanceof Array)) {
|
||||||
|
console.log(`connection rejected: tried to connect combo to ${obj}`);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (combo.length !== obj.length) {
|
||||||
|
console.log(`connection rejected: combo lists dont match`);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (combo.find((v, i) => obj[i] !== v)) {
|
||||||
|
console.log(`connection rejected: combo lists dont match`);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
__name(isValidCombo, "isValidCombo");
|
||||||
|
function isPrimitiveNode(node) {
|
||||||
|
return node.type === "PrimitiveNode";
|
||||||
|
}
|
||||||
|
__name(isPrimitiveNode, "isPrimitiveNode");
|
||||||
|
function setWidgetConfig(slot, config, target) {
|
||||||
|
if (!slot.widget) return;
|
||||||
|
if (config) {
|
||||||
|
slot.widget[GET_CONFIG] = () => config;
|
||||||
|
slot.widget[TARGET] = target;
|
||||||
|
} else {
|
||||||
|
delete slot.widget;
|
||||||
|
}
|
||||||
|
if (slot.link) {
|
||||||
|
const link = app.graph.links[slot.link];
|
||||||
|
if (link) {
|
||||||
|
const originNode = app.graph.getNodeById(link.origin_id);
|
||||||
|
if (isPrimitiveNode(originNode)) {
|
||||||
|
if (config) {
|
||||||
|
originNode.recreateWidget();
|
||||||
|
} else if (!app.configuringGraph) {
|
||||||
|
originNode.disconnectOutput(0);
|
||||||
|
originNode.onLastDisconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__name(setWidgetConfig, "setWidgetConfig");
|
||||||
|
function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
|
||||||
|
if (!config1) {
|
||||||
|
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||||
|
}
|
||||||
|
if (config1[0] instanceof Array) {
|
||||||
|
if (!isValidCombo(config1[0], config2[0])) return;
|
||||||
|
} else if (config1[0] !== config2[0]) {
|
||||||
|
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const keys = /* @__PURE__ */ new Set([
|
||||||
|
...Object.keys(config1[1] ?? {}),
|
||||||
|
...Object.keys(config2[1] ?? {})
|
||||||
|
]);
|
||||||
|
let customConfig;
|
||||||
|
const getCustomConfig = /* @__PURE__ */ __name(() => {
|
||||||
|
if (!customConfig) {
|
||||||
|
if (typeof structuredClone === "undefined") {
|
||||||
|
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
|
||||||
|
} else {
|
||||||
|
customConfig = structuredClone(config1[1] ?? {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return customConfig;
|
||||||
|
}, "getCustomConfig");
|
||||||
|
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||||
|
for (const k of keys.values()) {
|
||||||
|
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline" && k !== "tooltip") {
|
||||||
|
let v1 = config1[1][k];
|
||||||
|
let v2 = config2[1]?.[k];
|
||||||
|
if (v1 === v2 || !v1 && !v2) continue;
|
||||||
|
if (isNumber) {
|
||||||
|
if (k === "min") {
|
||||||
|
const theirMax = config2[1]?.["max"];
|
||||||
|
if (theirMax != null && v1 > theirMax) {
|
||||||
|
console.log("connection rejected: min > max", v1, theirMax);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
|
||||||
|
continue;
|
||||||
|
} else if (k === "max") {
|
||||||
|
const theirMin = config2[1]?.["min"];
|
||||||
|
if (theirMin != null && v1 < theirMin) {
|
||||||
|
console.log("connection rejected: max < min", v1, theirMin);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
|
||||||
|
continue;
|
||||||
|
} else if (k === "step") {
|
||||||
|
let step;
|
||||||
|
if (v1 == null) {
|
||||||
|
step = v2;
|
||||||
|
} else if (v2 == null) {
|
||||||
|
step = v1;
|
||||||
|
} else {
|
||||||
|
if (v1 < v2) {
|
||||||
|
const a = v2;
|
||||||
|
v2 = v1;
|
||||||
|
v1 = a;
|
||||||
|
}
|
||||||
|
if (v1 % v2) {
|
||||||
|
console.log(
|
||||||
|
"connection rejected: steps not divisible",
|
||||||
|
"current:",
|
||||||
|
v1,
|
||||||
|
"new:",
|
||||||
|
v2
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
step = v1;
|
||||||
|
}
|
||||||
|
getCustomConfig()[k] = step;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (customConfig || forceUpdate) {
|
||||||
|
if (customConfig) {
|
||||||
|
output.widget[CONFIG] = [config1[0], customConfig];
|
||||||
|
}
|
||||||
|
const widget = recreateWidget?.call(this);
|
||||||
|
if (widget) {
|
||||||
|
const min = widget.options.min;
|
||||||
|
const max = widget.options.max;
|
||||||
|
if (min != null && widget.value < min) widget.value = min;
|
||||||
|
if (max != null && widget.value > max) widget.value = max;
|
||||||
|
widget.callback(widget.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { customConfig };
|
||||||
|
}
|
||||||
|
__name(mergeIfValid, "mergeIfValid");
|
||||||
|
let useConversionSubmenusSetting;
|
||||||
|
app.registerExtension({
|
||||||
|
name: "Comfy.WidgetInputs",
|
||||||
|
init() {
|
||||||
|
useConversionSubmenusSetting = app.ui.settings.addSetting({
|
||||||
|
id: "Comfy.NodeInputConversionSubmenus",
|
||||||
|
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
|
||||||
|
type: "boolean",
|
||||||
|
defaultValue: true
|
||||||
|
});
|
||||||
|
},
|
||||||
|
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
|
||||||
|
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
|
||||||
|
nodeType.prototype.convertWidgetToInput = function(widget) {
|
||||||
|
const config = getConfig.call(this, widget.name) ?? [
|
||||||
|
widget.type,
|
||||||
|
widget.options || {}
|
||||||
|
];
|
||||||
|
if (!isConvertibleWidget(widget, config)) return false;
|
||||||
|
if (widget.type?.startsWith(CONVERTED_TYPE)) return false;
|
||||||
|
convertToInput(this, widget, config);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
nodeType.prototype.getExtraMenuOptions = function(_, options) {
|
||||||
|
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
|
||||||
|
if (this.widgets) {
|
||||||
|
let toInput = [];
|
||||||
|
let toWidget = [];
|
||||||
|
for (const w of this.widgets) {
|
||||||
|
if (w.options?.forceInput) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (w.type === CONVERTED_TYPE) {
|
||||||
|
toWidget.push({
|
||||||
|
content: `Convert ${w.name} to widget`,
|
||||||
|
callback: /* @__PURE__ */ __name(() => convertToWidget(this, w), "callback")
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
const config = getConfig.call(this, w.name) ?? [
|
||||||
|
w.type,
|
||||||
|
w.options || {}
|
||||||
|
];
|
||||||
|
if (isConvertibleWidget(w, config)) {
|
||||||
|
toInput.push({
|
||||||
|
content: `Convert ${w.name} to input`,
|
||||||
|
callback: /* @__PURE__ */ __name(() => convertToInput(this, w, config), "callback")
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (toInput.length) {
|
||||||
|
if (useConversionSubmenusSetting.value) {
|
||||||
|
options.push({
|
||||||
|
content: "Convert Widget to Input",
|
||||||
|
submenu: {
|
||||||
|
options: toInput
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
options.push(...toInput, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (toWidget.length) {
|
||||||
|
if (useConversionSubmenusSetting.value) {
|
||||||
|
options.push({
|
||||||
|
content: "Convert Input to Widget",
|
||||||
|
submenu: {
|
||||||
|
options: toWidget
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
options.push(...toWidget, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
nodeType.prototype.onGraphConfigured = function() {
|
||||||
|
if (!this.inputs) return;
|
||||||
|
this.widgets ??= [];
|
||||||
|
for (const input of this.inputs) {
|
||||||
|
if (input.widget) {
|
||||||
|
if (!input.widget[GET_CONFIG]) {
|
||||||
|
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
|
||||||
|
}
|
||||||
|
if (input.widget.config) {
|
||||||
|
if (input.widget.config[0] instanceof Array) {
|
||||||
|
input.type = "COMBO";
|
||||||
|
const link = app2.graph.links[input.link];
|
||||||
|
if (link) {
|
||||||
|
link.type = input.type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete input.widget.config;
|
||||||
|
}
|
||||||
|
const w = this.widgets.find((w2) => w2.name === input.widget.name);
|
||||||
|
if (w) {
|
||||||
|
hideWidget(this, w);
|
||||||
|
} else {
|
||||||
|
convertToWidget(this, input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
|
nodeType.prototype.onNodeCreated = function() {
|
||||||
|
const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : void 0;
|
||||||
|
if (!app2.configuringGraph && this.widgets) {
|
||||||
|
for (const w of this.widgets) {
|
||||||
|
if (w?.options?.forceInput || w?.options?.defaultInput) {
|
||||||
|
const config = getConfig.call(this, w.name) ?? [
|
||||||
|
w.type,
|
||||||
|
w.options || {}
|
||||||
|
];
|
||||||
|
convertToInput(this, w, config);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||||
|
nodeType.prototype.onConfigure = function() {
|
||||||
|
const r = origOnConfigure ? origOnConfigure.apply(this, arguments) : void 0;
|
||||||
|
if (!app2.configuringGraph && this.inputs) {
|
||||||
|
for (const input of this.inputs) {
|
||||||
|
if (input.widget && !input.widget[GET_CONFIG]) {
|
||||||
|
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
|
||||||
|
const w = this.widgets.find((w2) => w2.name === input.widget.name);
|
||||||
|
if (w) {
|
||||||
|
hideWidget(this, w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
function isNodeAtPos(pos) {
|
||||||
|
for (const n of app2.graph.nodes) {
|
||||||
|
if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
__name(isNodeAtPos, "isNodeAtPos");
|
||||||
|
const origOnInputDblClick = nodeType.prototype.onInputDblClick;
|
||||||
|
const ignoreDblClick = Symbol();
|
||||||
|
nodeType.prototype.onInputDblClick = function(slot) {
|
||||||
|
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : void 0;
|
||||||
|
const input = this.inputs[slot];
|
||||||
|
if (!input.widget || !input[ignoreDblClick]) {
|
||||||
|
if (!(input.type in ComfyWidgets) && !(input.widget[GET_CONFIG]?.()?.[0] instanceof Array)) {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const node = LiteGraph.createNode("PrimitiveNode");
|
||||||
|
app2.graph.add(node);
|
||||||
|
const pos = [
|
||||||
|
this.pos[0] - node.size[0] - 30,
|
||||||
|
this.pos[1]
|
||||||
|
];
|
||||||
|
while (isNodeAtPos(pos)) {
|
||||||
|
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||||
|
}
|
||||||
|
node.pos = pos;
|
||||||
|
node.connect(0, this, slot);
|
||||||
|
node.title = input.name;
|
||||||
|
input[ignoreDblClick] = true;
|
||||||
|
setTimeout(() => {
|
||||||
|
delete input[ignoreDblClick];
|
||||||
|
}, 300);
|
||||||
|
return r;
|
||||||
|
};
|
||||||
|
const onConnectInput = nodeType.prototype.onConnectInput;
|
||||||
|
nodeType.prototype.onConnectInput = function(targetSlot, type, output, originNode, originSlot) {
|
||||||
|
const v = onConnectInput?.(this, arguments);
|
||||||
|
if (type !== "COMBO") return v;
|
||||||
|
if (originNode.outputs[originSlot].widget) return v;
|
||||||
|
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
|
||||||
|
if (!targetCombo || !(targetCombo instanceof Array)) return v;
|
||||||
|
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
|
||||||
|
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
};
|
||||||
|
},
|
||||||
|
registerCustomNodes() {
|
||||||
|
LiteGraph.registerNodeType(
|
||||||
|
"PrimitiveNode",
|
||||||
|
Object.assign(PrimitiveNode, {
|
||||||
|
title: "Primitive"
|
||||||
|
})
|
||||||
|
);
|
||||||
|
PrimitiveNode.category = "utils";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
window.comfyAPI = window.comfyAPI || {};
|
||||||
|
window.comfyAPI.widgetInputs = window.comfyAPI.widgetInputs || {};
|
||||||
|
window.comfyAPI.widgetInputs.getWidgetConfig = getWidgetConfig;
|
||||||
|
window.comfyAPI.widgetInputs.convertToInput = convertToInput;
|
||||||
|
window.comfyAPI.widgetInputs.setWidgetConfig = setWidgetConfig;
|
||||||
|
window.comfyAPI.widgetInputs.mergeIfValid = mergeIfValid;
|
||||||
|
export {
|
||||||
|
convertToInput,
|
||||||
|
getWidgetConfig,
|
||||||
|
mergeIfValid,
|
||||||
|
setWidgetConfig
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=widgetInputs-DdoWwzg5.js.map
|
||||||
1
web/assets/widgetInputs-DdoWwzg5.js.map
generated
vendored
Normal file
1
web/assets/widgetInputs-DdoWwzg5.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
2
web/extensions/core/clipspace.js
vendored
2
web/extensions/core/clipspace.js
vendored
@@ -1,2 +1,2 @@
|
|||||||
// Shim for extensions\core\clipspace.ts
|
// Shim for extensions/core/clipspace.ts
|
||||||
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;
|
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;
|
||||||
|
|||||||
3
web/extensions/core/colorPalette.js
vendored
Normal file
3
web/extensions/core/colorPalette.js
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
// Shim for extensions/core/colorPalette.ts
|
||||||
|
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
|
||||||
|
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;
|
||||||
2
web/extensions/core/groupNode.js
vendored
2
web/extensions/core/groupNode.js
vendored
@@ -1,3 +1,3 @@
|
|||||||
// Shim for extensions\core\groupNode.ts
|
// Shim for extensions/core/groupNode.ts
|
||||||
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
|
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
|
||||||
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;
|
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;
|
||||||
|
|||||||
2
web/extensions/core/groupNodeManage.js
vendored
2
web/extensions/core/groupNodeManage.js
vendored
@@ -1,2 +1,2 @@
|
|||||||
// Shim for extensions\core\groupNodeManage.ts
|
// Shim for extensions/core/groupNodeManage.ts
|
||||||
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;
|
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;
|
||||||
|
|||||||
3
web/extensions/core/widgetInputs.js
vendored
3
web/extensions/core/widgetInputs.js
vendored
@@ -1,4 +1,5 @@
|
|||||||
// Shim for extensions\core\widgetInputs.ts
|
// Shim for extensions/core/widgetInputs.ts
|
||||||
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
|
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
|
||||||
|
export const convertToInput = window.comfyAPI.widgetInputs.convertToInput;
|
||||||
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
|
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
|
||||||
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;
|
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;
|
||||||
|
|||||||
92
web/index.html
vendored
92
web/index.html
vendored
@@ -1,50 +1,42 @@
|
|||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<title>ComfyUI</title>
|
<title>ComfyUI</title>
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
||||||
<!-- Browser Test Fonts -->
|
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||||
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
|
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||||
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
|
<script type="module" crossorigin src="./assets/index-DGAbdBYF.js"></script>
|
||||||
<style>
|
<link rel="stylesheet" crossorigin href="./assets/index-BHJGjcJh.css">
|
||||||
* {
|
</head>
|
||||||
font-family: 'Roboto Mono', 'Noto Color Emoji';
|
<body class="litegraph grid">
|
||||||
}
|
<div id="vue-app"></div>
|
||||||
</style> -->
|
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
|
||||||
|
<main class="comfy-user-selection-inner">
|
||||||
|
<h1>ComfyUI</h1>
|
||||||
|
<form>
|
||||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
<section>
|
||||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
<label>New user:
|
||||||
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
|
<input placeholder="Enter a username" />
|
||||||
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
|
</label>
|
||||||
</head>
|
</section>
|
||||||
<body class="litegraph">
|
<div class="comfy-user-existing">
|
||||||
<div id="vue-app"></div>
|
<span class="or-separator">OR</span>
|
||||||
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
|
<section>
|
||||||
<main class="comfy-user-selection-inner">
|
<label>
|
||||||
<h1>ComfyUI</h1>
|
Existing user:
|
||||||
<form>
|
<select>
|
||||||
<section>
|
<option hidden disabled selected value> Select a user </option>
|
||||||
<label>New user:
|
</select>
|
||||||
<input placeholder="Enter a username" />
|
</label>
|
||||||
</label>
|
</section>
|
||||||
</section>
|
</div>
|
||||||
<div class="comfy-user-existing">
|
<footer>
|
||||||
<span class="or-separator">OR</span>
|
<span class="comfy-user-error"> </span>
|
||||||
<section>
|
<button class="comfy-btn comfy-user-button-next">Next</button>
|
||||||
<label>
|
</footer>
|
||||||
Existing user:
|
</form>
|
||||||
<select>
|
</main>
|
||||||
<option hidden disabled selected value> Select a user </option>
|
</div>
|
||||||
</select>
|
</body>
|
||||||
</label>
|
</html>
|
||||||
</section>
|
|
||||||
</div>
|
|
||||||
<footer>
|
|
||||||
<span class="comfy-user-error"> </span>
|
|
||||||
<button class="comfy-btn comfy-user-button-next">Next</button>
|
|
||||||
</footer>
|
|
||||||
</form>
|
|
||||||
</main>
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user