Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f9d5a244b | ||
|
|
14eba07acd | ||
|
|
4b2f0d9413 | ||
|
|
25eac1d780 | ||
|
|
e38c94228b | ||
|
|
203942c8b2 | ||
|
|
3c72c89a52 | ||
|
|
614377abd6 | ||
|
|
8dfa0cc552 | ||
|
|
e5ecdfdd2d | ||
|
|
7d29fbf74b | ||
|
|
2c641e64ad | ||
|
|
7d2467e830 | ||
|
|
6f021d8aa0 | ||
|
|
d854ed0bcf | ||
|
|
abcd006b8c | ||
|
|
d985d1d7dc | ||
|
|
d1cdf51e1b | ||
|
|
b4626ab93e | ||
|
|
a9e459c2a4 | ||
|
|
3bb4dec720 | ||
|
|
8733191563 | ||
|
|
83b01f960a | ||
|
|
d72e871cfa | ||
|
|
037c3159b6 | ||
|
|
bdd4a22a2e | ||
|
|
fdf37566ef | ||
|
|
08c8968482 | ||
|
|
479a427a48 | ||
|
|
3a0eeee320 | ||
|
|
447da7ea86 | ||
|
|
9c41bc8d10 | ||
|
|
6ad0ddbae4 | ||
|
|
a55142f904 | ||
|
|
5718ef69bb | ||
|
|
13ecf10a92 | ||
|
|
7a415f47a9 | ||
|
|
89fa2fca24 | ||
|
|
364b69e931 | ||
|
|
dc96a1ae19 | ||
|
|
2d810b081e | ||
|
|
9f7e9f0547 | ||
|
|
a355f38ecc | ||
|
|
38c69080c7 | ||
|
|
70a708d726 | ||
|
|
e7d4782736 | ||
|
|
3326bdfd4e | ||
|
|
68bb885d22 | ||
|
|
ad66f7c7d8 | ||
|
|
de8e8e3b0d | ||
|
|
a1e71cfad1 | ||
|
|
0bfc7cc998 | ||
|
|
7183fd1665 | ||
|
|
254838f23c | ||
|
|
0b7dfa986d | ||
|
|
d514bb38ee | ||
|
|
0849c80e2a | ||
|
|
56e8f5e4fd | ||
|
|
e813abbb2c | ||
|
|
5e68a4ce67 | ||
|
|
ca08597670 | ||
|
|
f48e390032 | ||
|
|
369a6dd2c4 | ||
|
|
b3ce8fb9fd | ||
|
|
cf80d28689 | ||
|
|
6fb44c4b7c | ||
|
|
d2247c1e61 | ||
|
|
cb12ad7049 | ||
|
|
f6b7194f64 | ||
|
|
7c6eb4fb29 | ||
|
|
b962db9952 | ||
|
|
d0b7ab88ba | ||
|
|
405b529545 | ||
|
|
9d720187f1 | ||
|
|
d247bc5a9c | ||
|
|
9f4daca9d9 | ||
|
|
b5d0f2a908 | ||
|
|
e760bf5c40 | ||
|
|
36c83cdbba | ||
|
|
81778a7feb | ||
|
|
bc94662b31 | ||
|
|
9fa8faa44a | ||
|
|
9a7444e39f | ||
|
|
54fca4a218 | ||
|
|
cd4955367e | ||
|
|
8354203d95 | ||
|
|
e0b41243b4 | ||
|
|
619263d4a6 | ||
|
|
e3b0402bb7 | ||
|
|
967867d48c | ||
|
|
cbaac71bf5 | ||
|
|
3ab3516e46 | ||
|
|
9c5fca75f4 | ||
|
|
a5da4d0b3e | ||
|
|
32a60a7bac | ||
|
|
bb52934ba4 | ||
|
|
8aabd7c8c0 | ||
|
|
a09b29ca11 | ||
|
|
9bfee68773 | ||
|
|
ea77750759 | ||
|
|
c27ebeb1c2 | ||
|
|
0c7c98a965 | ||
|
|
dc2eb75b85 | ||
|
|
fa34efe3bd | ||
|
|
5cbaa9e07c | ||
|
|
c7427375ee | ||
|
|
22d1241a50 | ||
|
|
f04229b84d | ||
|
|
f067ad15d1 | ||
|
|
483004dd1d | ||
|
|
00a5d08103 | ||
|
|
d043997d30 |
2
.github/workflows/pullrequest-ci-run.yml
vendored
2
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
|
||||
4
.github/workflows/test-ci.yml
vendored
4
.github/workflows/test-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
runner_label: [self-hosted, Windows]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
|
||||
30
.github/workflows/test-unit.yml
vendored
Normal file
30
.github/workflows/test-unit.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,6 +12,7 @@ extra_model_paths.yaml
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
|
||||
@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
| Alt + `+` | Canvas Zoom in |
|
||||
| Alt + `-` | Canvas Zoom out |
|
||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||
| P | Pin/Unpin selected nodes |
|
||||
| Ctrl + G | Group selected nodes |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
import app.logger
|
||||
|
||||
@@ -36,6 +36,13 @@ class InternalRoutes:
|
||||
async def get_logs(request):
|
||||
return web.json_response(app.logger.get_logs())
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
response = {}
|
||||
for key in folder_names_and_paths:
|
||||
response[key] = folder_names_and_paths[key][0]
|
||||
return web.json_response(response)
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
|
||||
@@ -10,14 +10,14 @@ def get_logs():
|
||||
return "\n".join([formatter.format(x) for x in logs])
|
||||
|
||||
|
||||
def setup_logger(verbose: bool = False, capacity: int = 300):
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
@@ -5,17 +5,17 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import user_directory
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
|
||||
default_user = "default"
|
||||
users_file = os.path.join(user_directory, "users.json")
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
self.settings = AppSettings(self)
|
||||
if not os.path.exists(user_directory):
|
||||
@@ -25,14 +25,17 @@ class UserManager():
|
||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||
|
||||
if args.multi_user:
|
||||
if os.path.isfile(users_file):
|
||||
with open(users_file) as f:
|
||||
if os.path.isfile(self.get_users_file()):
|
||||
with open(self.get_users_file()) as f:
|
||||
self.users = json.load(f)
|
||||
else:
|
||||
self.users = {}
|
||||
else:
|
||||
self.users = {"default": "default"}
|
||||
|
||||
def get_users_file(self):
|
||||
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||
|
||||
def get_request_user_id(self, request):
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
@@ -44,7 +47,7 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
global user_directory
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
@@ -59,6 +62,10 @@ class UserManager():
|
||||
return None
|
||||
|
||||
if file is not None:
|
||||
# Check if filename is url encoded
|
||||
if "%" in file:
|
||||
file = parse.unquote(file)
|
||||
|
||||
# prevent leaving /{type}/{user}
|
||||
path = os.path.abspath(os.path.join(user_root, file))
|
||||
if os.path.commonpath((user_root, path)) != user_root:
|
||||
@@ -80,8 +87,7 @@ class UserManager():
|
||||
|
||||
self.users[user_id] = name
|
||||
|
||||
global users_file
|
||||
with open(users_file, "w") as f:
|
||||
with open(self.get_users_file(), "w") as f:
|
||||
json.dump(self.users, f)
|
||||
|
||||
return user_id
|
||||
@@ -112,25 +118,69 @@ class UserManager():
|
||||
|
||||
@routes.get("/userdata")
|
||||
async def listuserdata(request):
|
||||
"""
|
||||
List user data files in a specified directory.
|
||||
|
||||
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||
full file information, and path splitting.
|
||||
|
||||
Query Parameters:
|
||||
- dir (required): The directory to list files from.
|
||||
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||
|
||||
Returns:
|
||||
- 400: If 'dir' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 404: If the requested directory does not exist.
|
||||
- 200: JSON response with the list of files or file information.
|
||||
|
||||
The response format depends on the query parameters:
|
||||
- Default: List of relative file paths.
|
||||
- full_info=true: List of dictionaries with file details.
|
||||
- split=true (and full_info=false): List of lists, each containing path components.
|
||||
"""
|
||||
directory = request.rel_url.query.get('dir', '')
|
||||
if not directory:
|
||||
return web.Response(status=400)
|
||||
|
||||
return web.Response(status=400, text="Directory not provided")
|
||||
|
||||
path = self.get_request_user_filepath(request, directory)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
return web.Response(status=403, text="Invalid directory")
|
||||
|
||||
if not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.Response(status=404, text="Directory not found")
|
||||
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
results = glob.glob(os.path.join(
|
||||
glob.escape(path), '**/*'), recursive=recurse)
|
||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
||||
|
||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||
|
||||
# Use different patterns based on whether we're recursing or not
|
||||
if recurse:
|
||||
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||
else:
|
||||
pattern = os.path.join(glob.escape(path), '*')
|
||||
|
||||
results = glob.glob(pattern, recursive=recurse)
|
||||
|
||||
if full_info:
|
||||
results = [
|
||||
{
|
||||
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||
'size': os.path.getsize(x),
|
||||
'modified': os.path.getmtime(x)
|
||||
} for x in results if os.path.isfile(x)
|
||||
]
|
||||
else:
|
||||
results = [
|
||||
os.path.relpath(x, path).replace(os.sep, '/')
|
||||
for x in results
|
||||
if os.path.isfile(x)
|
||||
]
|
||||
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path:
|
||||
results = [[x] + x.split(os.sep) for x in results]
|
||||
if split_path and not full_info:
|
||||
results = [[x] + x.split('/') for x in results]
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@@ -138,14 +188,14 @@ class UserManager():
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
return web.Response(status=400)
|
||||
|
||||
|
||||
path = self.get_request_user_filepath(request, file)
|
||||
if not path:
|
||||
return web.Response(status=403)
|
||||
|
||||
|
||||
if check_exists and not os.path.exists(path):
|
||||
return web.Response(status=404)
|
||||
|
||||
|
||||
return path
|
||||
|
||||
@routes.get("/userdata/{file}")
|
||||
@@ -153,7 +203,7 @@ class UserManager():
|
||||
path = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
|
||||
return web.FileResponse(path)
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
@@ -161,7 +211,7 @@ class UserManager():
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
@@ -170,7 +220,7 @@ class UserManager():
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
return web.json_response(resp)
|
||||
|
||||
@@ -181,7 +231,7 @@ class UserManager():
|
||||
return path
|
||||
|
||||
os.remove(path)
|
||||
|
||||
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
@@ -189,17 +239,17 @@ class UserManager():
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
return web.json_response(resp)
|
||||
|
||||
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks = None,
|
||||
control_latent_channels = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||
for _ in range(len(self.joint_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
|
||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||
None,
|
||||
self.patch_size,
|
||||
self.in_channels,
|
||||
control_latent_channels,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=False,
|
||||
|
||||
@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@@ -169,6 +171,8 @@ parser.add_argument(
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
sd.pop(k)
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
||||
@@ -79,13 +79,21 @@ class ControlBase:
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
self.extra_concat_orig = []
|
||||
self.extra_concat = None
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
if vae is None:
|
||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
self.vae = vae
|
||||
self.extra_concat_orig = extra_concat.copy()
|
||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||
return self
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
@@ -100,9 +108,9 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
@@ -123,6 +131,8 @@ class ControlBase:
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
c.concat_mask = self.concat_mask
|
||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@@ -175,7 +185,7 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
@@ -189,6 +199,7 @@ class ControlNet(ControlBase):
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
self.concat_mask = concat_mask
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@@ -213,6 +224,9 @@ class ControlNet(ControlBase):
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@@ -220,6 +234,13 @@ class ControlNet(ControlBase):
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
if self.latent_format is not None:
|
||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||
if len(self.extra_concat_orig) > 0:
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
@@ -319,7 +340,7 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
|
||||
ControlBase.__init__(self, device)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
@@ -376,19 +397,25 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def controlnet_config(sd):
|
||||
def controlnet_config(sd, model_options={}):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
@@ -403,24 +430,29 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
def load_controlnet_mmdit(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
concat_mask = False
|
||||
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||
if control_latent_channels == 17: #inpaint controlnet
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
@@ -430,17 +462,17 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_xlabs(sd):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_instantx(sd):
|
||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
@@ -449,21 +481,30 @@ def load_controlnet_flux_instantx(sd):
|
||||
if union_cnet in new_sd:
|
||||
num_union_modes = new_sd[union_cnet].shape[0]
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||
concat_mask = False
|
||||
if control_latent_channels == 17:
|
||||
concat_mask = True
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
|
||||
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
controlnet_data = state_dict
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
return ControlLora(controlnet_data, model_options=model_options)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
@@ -518,13 +559,15 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs(controlnet_data)
|
||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data)
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@@ -536,25 +579,36 @@ def load_controlnet(ckpt_path, model=None):
|
||||
elif key in controlnet_data:
|
||||
prefix = ""
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||
if net is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
logging.error("error could not detect control model type.")
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
unet_dtype = model_options.get("dtype", None)
|
||||
if unet_dtype is None:
|
||||
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||
|
||||
if supported_inference_dtypes is None:
|
||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||
|
||||
if weight_dtype is not None:
|
||||
supported_inference_dtypes.append(weight_dtype)
|
||||
|
||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||
|
||||
controlnet_config["operations"] = operations
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
@@ -590,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||
if "global_average_pooling" not in model_options:
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
model_options["global_average_pooling"] = True
|
||||
|
||||
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||
if cnet is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
return cnet
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||
super().__init__(device)
|
||||
@@ -653,7 +714,7 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
@@ -41,9 +40,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
del abs_x
|
||||
|
||||
return sign.to(dtype=dtype)
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
for i in range(0, value.shape[0], slice_size):
|
||||
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||
epsilon = 1e-5 # avoid log(0)
|
||||
x = torch.linspace(0, 1, n, device=device)
|
||||
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||
sigmas = clamp(torch.exp(lmb))
|
||||
return sigmas
|
||||
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||
@@ -1101,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = denoised + d * sigma_down
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||
r = 1 / 2
|
||||
h = t_next - t
|
||||
s = t + r * h
|
||||
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
old_uncond_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||
old_uncond_denoised = uncond_denoised
|
||||
return x
|
||||
@@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
[ 0.3651, 0.4232, 0.4341],
|
||||
[-0.2533, -0.0042, 0.1068],
|
||||
[ 0.1076, 0.1111, -0.0362],
|
||||
[-0.3165, -0.2492, -0.2188]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
|
||||
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360],
|
||||
[ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
[-0.0922, -0.0175, 0.0749],
|
||||
[ 0.0311, 0.0633, 0.0954],
|
||||
[ 0.1994, 0.0927, 0.0458],
|
||||
[ 0.0856, 0.0339, 0.0902],
|
||||
[ 0.0587, 0.0272, -0.0496],
|
||||
[-0.0006, 0.1104, 0.0309],
|
||||
[ 0.0978, 0.0306, 0.0427],
|
||||
[-0.0042, 0.1038, 0.1358],
|
||||
[-0.0194, 0.0020, 0.0669],
|
||||
[-0.0488, 0.0130, -0.0268],
|
||||
[ 0.0922, 0.0988, 0.0951],
|
||||
[-0.0278, 0.0524, -0.0542],
|
||||
[ 0.0332, 0.0456, 0.0895],
|
||||
[-0.0069, -0.0030, -0.0810],
|
||||
[-0.0596, -0.0465, -0.0293],
|
||||
[-0.1448, -0.1463, -0.1189]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -146,23 +150,24 @@ class Flux(SD3):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0404, 0.0159, 0.0609],
|
||||
[ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503],
|
||||
[-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530],
|
||||
[ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262],
|
||||
[-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001],
|
||||
[ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013],
|
||||
[ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976],
|
||||
[-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
[-0.0346, 0.0244, 0.0681],
|
||||
[ 0.0034, 0.0210, 0.0687],
|
||||
[ 0.0275, -0.0668, -0.0433],
|
||||
[-0.0174, 0.0160, 0.0617],
|
||||
[ 0.0859, 0.0721, 0.0329],
|
||||
[ 0.0004, 0.0383, 0.0115],
|
||||
[ 0.0405, 0.0861, 0.0915],
|
||||
[-0.0236, -0.0185, -0.0259],
|
||||
[-0.0245, 0.0250, 0.1180],
|
||||
[ 0.1008, 0.0755, -0.0421],
|
||||
[-0.0515, 0.0201, 0.0011],
|
||||
[ 0.0428, -0.0012, -0.0036],
|
||||
[ 0.0817, 0.0765, 0.0749],
|
||||
[-0.1264, -0.0522, -0.1103],
|
||||
[-0.0280, -0.0881, -0.0499],
|
||||
[-0.1262, -0.0982, -0.0778]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.taesd_decoder_name = "taef1_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
|
||||
@@ -14,7 +14,7 @@ except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight, eps=1e-6):
|
||||
if rms_norm_torch is not None:
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
#modified to support different types of flux controlnets
|
||||
|
||||
import torch
|
||||
import math
|
||||
@@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class MistolineCondDownsamplBlock(nn.Module):
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
class MistolineControlnetBlock(nn.Module):
|
||||
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.linear(x))
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
|
||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
|
||||
self.main_model_double = 19
|
||||
self.main_model_single = 38
|
||||
|
||||
self.mistoline = mistoline
|
||||
# add ControlNet blocks
|
||||
if self.mistoline:
|
||||
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.controlnet_blocks.append(control_block())
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth_single_blocks):
|
||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
||||
self.controlnet_single_blocks.append(control_block())
|
||||
|
||||
self.num_union_modes = num_union_modes
|
||||
self.controlnet_mode_embedder = None
|
||||
@@ -36,25 +80,33 @@ class ControlNetFlux(Flux):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.latent_input = latent_input
|
||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if control_latent_channels is None:
|
||||
control_latent_channels = self.in_channels
|
||||
else:
|
||||
control_latent_channels *= 2 * 2 #patch size
|
||||
|
||||
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if not self.latent_input:
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
if self.mistoline:
|
||||
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -73,9 +125,6 @@ class ControlNetFlux(Flux):
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
if not self.latent_input:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
@@ -131,9 +180,14 @@ class ControlNetFlux(Flux):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
elif self.mistoline:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_cond_block(hint)
|
||||
else:
|
||||
hint = hint * 2.0 - 1.0
|
||||
hint = self.input_hint_block(hint)
|
||||
|
||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
@@ -108,7 +108,7 @@ class Flux(nn.Module):
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y)
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
@@ -151,8 +151,8 @@ class Flux(nn.Module):
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if "emb_patch" in transformer_patches:
|
||||
patch = transformer_patches["emb_patch"]
|
||||
for p in patch:
|
||||
emb = p(emb, self.model_channels, transformer_options)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -201,9 +201,13 @@ def load_lora(lora, to_load):
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
clip_g_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
@@ -227,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
clip_g_present = True
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
@@ -242,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
t5_index = 1
|
||||
if clip_g_present:
|
||||
t5_index += 1
|
||||
if clip_l_present:
|
||||
t5_index += 1
|
||||
if t5_index == 2:
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||
t5_index += 1
|
||||
|
||||
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
@@ -281,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
@@ -324,14 +338,15 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
to = diffusers_keys[k]
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
@@ -400,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
@@ -438,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -484,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -521,28 +536,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
old_glora = False
|
||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||
rank = v[0].shape[0]
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
|
||||
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
|
||||
@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||
logging.info("pytorch version: {}".format(torch_version))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -326,7 +326,7 @@ class LoadedModel:
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||
with torch.no_grad():
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
@@ -426,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if model_params < 0:
|
||||
model_params = 1000000000000000000000
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@@ -897,7 +899,7 @@ def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
@@ -1063,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
@@ -1070,6 +1075,14 @@ def supports_fp8_compute(device=None):
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
|
||||
@@ -28,7 +28,7 @@ import comfy.utils
|
||||
import comfy.float
|
||||
import comfy.model_management
|
||||
import comfy.lora
|
||||
from comfy.types import UnetWrapperFunction
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
@@ -88,8 +88,12 @@ class LowVramPatch:
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
intermediate_dtype = weight.dtype
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@@ -283,17 +287,21 @@ class ModelPatcher:
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
comfy.model_management.unload_model_clones(self)
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
if filter_prefix is not None:
|
||||
if not k.startswith(filter_prefix):
|
||||
continue
|
||||
if k in self.patches:
|
||||
p[k] = [model_sd[k]] + self.patches[k]
|
||||
bk = self.backup.get(k, None)
|
||||
if bk is not None:
|
||||
weight = bk.weight
|
||||
else:
|
||||
p[k] = (model_sd[k],)
|
||||
weight = model_sd[k]
|
||||
if k in self.patches:
|
||||
p[k] = [weight] + self.patches[k]
|
||||
else:
|
||||
p[k] = (weight,)
|
||||
return p
|
||||
|
||||
def model_state_dict(self, filter_prefix=None):
|
||||
|
||||
@@ -260,7 +260,6 @@ def fp8_linear(self, input):
|
||||
|
||||
if len(input.shape) == 3:
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
@@ -300,10 +299,14 @@ class fp8_ops(manual_cast):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
|
||||
if comfy.model_management.supports_fp8_compute(load_device):
|
||||
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||
return fp8_ops
|
||||
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
if args.fast:
|
||||
if args.fast and not disable_fast_fp8:
|
||||
if comfy.model_management.supports_fp8_compute(load_device):
|
||||
return fp8_ops
|
||||
return manual_cast
|
||||
|
||||
@@ -6,7 +6,7 @@ from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
import scipy
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
@@ -570,8 +570,8 @@ class Sampler:
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
|
||||
79
comfy/sd.py
79
comfy/sd.py
@@ -29,7 +29,6 @@ import comfy.text_encoders.long_clipl
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
import comfy.taesd.taesd
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
@@ -70,14 +69,14 @@ class CLIP:
|
||||
clip = target.clip
|
||||
tokenizer = target.tokenizer
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||
dtype = model_options.get("dtype", None)
|
||||
if dtype is None:
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||
params['model_options'] = model_options
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
@@ -348,7 +347,7 @@ class VAE:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
@@ -406,8 +405,35 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
class TEModel(Enum):
|
||||
CLIP_L = 1
|
||||
CLIP_H = 2
|
||||
CLIP_G = 3
|
||||
T5_XXL = 4
|
||||
T5_XL = 5
|
||||
T5_BASE = 6
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_G
|
||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_H
|
||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_L
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
return TEModel.T5_XL
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
return TEModel.T5_BASE
|
||||
return None
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = state_dicts
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
@@ -421,39 +447,42 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target = EmptyClass()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
te_model = detect_te_model(clip_data[0])
|
||||
if te_model == TEModel.CLIP_G:
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
elif clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
elif te_model == TEModel.CLIP_H:
|
||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
elif te_model == TEModel.T5_XXL:
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif te_model == TEModel.T5_XL:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||
@@ -475,10 +504,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
parameters = 0
|
||||
tokenizer_data = {}
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@@ -562,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
@@ -647,7 +677,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
|
||||
@@ -542,6 +542,7 @@ class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -570,6 +571,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
|
||||
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -40,7 +41,8 @@ class SDXLTokenizer:
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
||||
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
|
||||
@@ -49,6 +49,7 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
|
||||
@@ -13,12 +13,13 @@ class T5XXLModel(sd1_clip.SDClipModel):
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, *args, **kwargs):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
|
||||
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||
if w is None:
|
||||
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
tokenizer_data = tokenizer_data.copy()
|
||||
model_options = model_options.copy()
|
||||
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||
model_options["clip_l_class"] = LongClipModel_
|
||||
return tokenizer_data, model_options
|
||||
|
||||
@@ -20,7 +20,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
@@ -42,7 +43,8 @@ class SD3ClipModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
@@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
|
||||
if self.clip_g is not None:
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
if lg_out is not None:
|
||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||
else:
|
||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||
else:
|
||||
|
||||
@@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||
return rows * cols
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
@@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
|
||||
# handle entire input fitting in a single tile
|
||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||
output[b:b+1] = function(s).to(output_device)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
|
||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
for it in itertools.product(*positions):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
@@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
m = mask.narrow(d, t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
||||
m *= ((1.0/feather) * (t + 1))
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
||||
o = out
|
||||
o_d = out_div
|
||||
@@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
|
||||
o += ps * mask
|
||||
o_d += mask
|
||||
o.add_(ps * mask)
|
||||
o_d.add_(mask)
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
|
||||
import nodes
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def include_unique_id_in_input(class_type: str) -> bool:
|
||||
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet:
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
|
||||
@@ -99,30 +99,44 @@ class TopologicalSort:
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
if not self.is_cached(from_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
node_ids = [node_unique_id]
|
||||
links = []
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if include_lazy or not is_lazy:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
while len(node_ids) > 0:
|
||||
unique_id = node_ids.pop()
|
||||
if unique_id in self.pendingNodes:
|
||||
continue
|
||||
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return False
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if self.output_cache.get(from_node_id) is not None:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
|
||||
@@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds):
|
||||
batch_size = 1
|
||||
def generate(self, seconds, batch_size):
|
||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||
return ({"samples":latent, "type": "audio"}, )
|
||||
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
||||
|
||||
def decode(self, vae, samples):
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
|
||||
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
||||
}
|
||||
|
||||
class LoadAudio:
|
||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [
|
||||
f for f in os.listdir(input_dir)
|
||||
if (os.path.isfile(os.path.join(input_dir, f))
|
||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
||||
)
|
||||
]
|
||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
import nodes
|
||||
import comfy.utils
|
||||
|
||||
class SetUnionControlNetType:
|
||||
@classmethod
|
||||
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
||||
|
||||
return (control_net,)
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
|
||||
FUNCTION = "apply_inpaint_controlnet"
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||
extra_concat = []
|
||||
if control_net.concat_mask:
|
||||
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||
extra_concat = [mask]
|
||||
|
||||
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SetUnionControlNetType": SetUnionControlNetType,
|
||||
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||
}
|
||||
|
||||
@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||
return (sigmas, )
|
||||
|
||||
class LaplaceScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
|
||||
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
|
||||
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
||||
return (sigmas, )
|
||||
|
||||
|
||||
class SDTurboScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"KarrasScheduler": KarrasScheduler,
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"LaplaceScheduler": LaplaceScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
|
||||
@@ -107,7 +107,7 @@ class HypernetworkLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
|
||||
115
comfy_extras/nodes_lora_extract.py
Normal file
115
comfy_extras/nodes_lora_extract.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import os
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
def extract_lora(diff, rank):
|
||||
conv2d = (len(diff.shape) == 4)
|
||||
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = diff.size()[0:2]
|
||||
rank = min(rank, in_dim, out_dim)
|
||||
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
diff = diff.flatten(start_dim=1)
|
||||
else:
|
||||
diff = diff.squeeze()
|
||||
|
||||
|
||||
U, S, Vh = torch.linalg.svd(diff.float())
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
return (U, Vh)
|
||||
|
||||
class LORAType(Enum):
|
||||
STANDARD = 0
|
||||
FULL_DIFF = 1
|
||||
|
||||
LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||
"full_diff": LORAType.FULL_DIFF}
|
||||
|
||||
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||
|
||||
for k in sd:
|
||||
if k.endswith(".weight"):
|
||||
weight_diff = sd[k]
|
||||
if lora_type == LORAType.STANDARD:
|
||||
if weight_diff.ndim < 2:
|
||||
if bias_diff:
|
||||
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||
continue
|
||||
try:
|
||||
out = extract_lora(weight_diff, rank)
|
||||
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
||||
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
||||
except:
|
||||
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
||||
elif lora_type == LORAType.FULL_DIFF:
|
||||
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||
|
||||
elif bias_diff and k.endswith(".bias"):
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||
return output_sd
|
||||
|
||||
class LoraSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {"model_diff": ("MODEL",),
|
||||
"text_encoder_diff": ("CLIP",)},
|
||||
}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||
if model_diff is None and text_encoder_diff is None:
|
||||
return {}
|
||||
|
||||
lora_type = LORA_TYPES.get(lora_type)
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
|
||||
output_sd = {}
|
||||
if model_diff is not None:
|
||||
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
||||
if text_encoder_diff is not None:
|
||||
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LoraSave": LoraSave
|
||||
}
|
||||
@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
|
||||
@@ -26,6 +26,7 @@ class PerpNeg:
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
DEPRECATED = True
|
||||
|
||||
def patch(self, model, empty_conditioning, neg_scale):
|
||||
m = model.clone()
|
||||
|
||||
@@ -126,7 +126,7 @@ class PhotoMakerLoader:
|
||||
CATEGORY = "_for_testing/photomaker"
|
||||
|
||||
def load_photomaker_model(self, photomaker_model_name):
|
||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
||||
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||
photomaker_model = PhotoMakerIDEncoder()
|
||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||
if "id_encoder" in data:
|
||||
|
||||
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
|
||||
CATEGORY = "latent/sd3"
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
|
||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||
return ({"samples":latent}, )
|
||||
|
||||
class CLIPTextEncodeSD3:
|
||||
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
DEPRECATED = True
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripleCLIPLoader": TripleCLIPLoader,
|
||||
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
||||
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
|
||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, image, vae):
|
||||
|
||||
@@ -154,7 +154,7 @@ class TomePatchModel:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, ratio):
|
||||
self.u = None
|
||||
|
||||
22
comfy_extras/nodes_torch_compile.py
Normal file
22
comfy_extras/nodes_torch_compile.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
class TorchCompileModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"backend": (["inductor", "cudagraphs"],),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, backend):
|
||||
m = model.clone()
|
||||
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TorchCompileModel": TorchCompileModel,
|
||||
}
|
||||
@@ -25,7 +25,7 @@ class UpscaleModelLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
||||
|
||||
@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
|
||||
CATEGORY = "loaders/video_models"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (out[0], out[3], out[2])
|
||||
|
||||
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
|
||||
return (m, )
|
||||
|
||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
||||
@@ -37,6 +37,7 @@ class SaveImageWebsocket:
|
||||
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, images):
|
||||
return time.time()
|
||||
|
||||
|
||||
@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
value = []
|
||||
for o in results:
|
||||
if isinstance(o[i], ExecutionBlocker):
|
||||
value.append(o[i])
|
||||
else:
|
||||
value.extend(o[i])
|
||||
output.append(value)
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
@@ -25,11 +25,16 @@ a111:
|
||||
|
||||
#comfyui:
|
||||
# base_path: path/to/comfyui/
|
||||
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
|
||||
# #is_default: true
|
||||
# checkpoints: models/checkpoints/
|
||||
# clip: models/clip/
|
||||
# clip_vision: models/clip_vision/
|
||||
# configs: models/configs/
|
||||
# controlnet: models/controlnet/
|
||||
# diffusion_models: |
|
||||
# models/diffusion_models
|
||||
# models/unet
|
||||
# embeddings: models/embeddings/
|
||||
# loras: models/loras/
|
||||
# upscale_models: models/upscale_models/
|
||||
|
||||
103
folder_paths.py
103
folder_paths.py
@@ -2,7 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import mimetypes
|
||||
import logging
|
||||
from typing import Set, List, Dict, Tuple, Literal
|
||||
from collections.abc import Collection
|
||||
|
||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
||||
|
||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||
|
||||
class CacheHelper:
|
||||
"""
|
||||
Helper class for managing file list cache data.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||
self.active = False
|
||||
|
||||
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||
if not self.active:
|
||||
return default
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||
if self.active:
|
||||
self.cache[key] = value
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.active = False
|
||||
self.clear()
|
||||
|
||||
cache_helper = CacheHelper()
|
||||
|
||||
extension_mimetypes_cache = {
|
||||
"webp" : "image",
|
||||
}
|
||||
|
||||
def map_legacy(folder_name: str) -> str:
|
||||
legacy = {"unet": "diffusion_models"}
|
||||
return legacy.get(folder_name, folder_name)
|
||||
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
|
||||
global input_directory
|
||||
return input_directory
|
||||
|
||||
def get_user_directory() -> str:
|
||||
return user_directory
|
||||
|
||||
def set_user_directory(user_dir: str) -> None:
|
||||
global user_directory
|
||||
user_directory = user_dir
|
||||
|
||||
|
||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||
def get_directory_by_type(type_name: str) -> str | None:
|
||||
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
||||
return get_input_directory()
|
||||
return None
|
||||
|
||||
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||
"""
|
||||
Example:
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
filter_files_content_types(files, ["image", "audio", "video"])
|
||||
"""
|
||||
global extension_mimetypes_cache
|
||||
result = []
|
||||
for file in files:
|
||||
extension = file.split('.')[-1]
|
||||
if extension not in extension_mimetypes_cache:
|
||||
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||
if not mime_type:
|
||||
continue
|
||||
content_type = mime_type.split('/')[0]
|
||||
extension_mimetypes_cache[extension] = content_type
|
||||
else:
|
||||
content_type = extension_mimetypes_cache[extension]
|
||||
|
||||
if content_type in content_types:
|
||||
result.append(file)
|
||||
return result
|
||||
|
||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||
# otherwise use default_path as base_dir
|
||||
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
||||
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name in folder_names_and_paths:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
if is_default:
|
||||
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||
else:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
else:
|
||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||
|
||||
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
for file_name in filenames:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path: str = os.path.join(dirpath, d)
|
||||
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
full_path = get_full_path(folder_name, filename)
|
||||
if full_path is None:
|
||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||
return full_path
|
||||
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||
|
||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||
strong_cache = cache_helper.get(folder_name)
|
||||
if strong_cache is not None:
|
||||
return strong_cache
|
||||
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
|
||||
out = get_filename_list_(folder_name)
|
||||
global filename_list_cache
|
||||
filename_list_cache[folder_name] = out
|
||||
cache_helper.set(folder_name, out)
|
||||
return list(out[0])
|
||||
|
||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
now = time.localtime()
|
||||
input = input.replace("%year%", str(now.tm_year))
|
||||
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
if "%" in filename_prefix:
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
@@ -9,7 +9,7 @@ import folder_paths
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = 512
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
|
||||
def preview_to_image(latent_image):
|
||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
||||
|
||||
|
||||
class Latent2RGBPreviewer(LatentPreviewer):
|
||||
def __init__(self, latent_rgb_factors):
|
||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
||||
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
|
||||
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
|
||||
self.latent_rgb_factors_bias = None
|
||||
if latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
|
||||
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
if self.latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
|
||||
return preview_to_image(latent_image)
|
||||
|
||||
|
||||
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
|
||||
|
||||
if previewer is None:
|
||||
if latent_format.latent_rgb_factors is not None:
|
||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
|
||||
return previewer
|
||||
|
||||
def prepare_callback(model, steps, x0_output_dict=None):
|
||||
|
||||
43
main.py
43
main.py
@@ -9,7 +9,7 @@ from comfy.cli_args import args
|
||||
from app.logger import setup_logger
|
||||
|
||||
|
||||
setup_logger(verbose=args.verbose)
|
||||
setup_logger(log_level=args.verbose)
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
@@ -63,6 +63,7 @@ import threading
|
||||
import gc
|
||||
|
||||
import logging
|
||||
import utils.extra_config
|
||||
|
||||
if os.name == "nt":
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
@@ -85,7 +86,6 @@ if args.windows_standalone_build:
|
||||
pass
|
||||
|
||||
import comfy.utils
|
||||
import yaml
|
||||
|
||||
import execution
|
||||
import server
|
||||
@@ -160,7 +160,10 @@ def prompt_worker(q, server):
|
||||
need_gc = False
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
addresses = []
|
||||
for addr in address.split(","):
|
||||
addresses.append((addr, port))
|
||||
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||
|
||||
|
||||
def hijack_progress(server):
|
||||
@@ -180,27 +183,6 @@ def cleanup_temp():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_extra_path_config(yaml_path):
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
for c in config:
|
||||
conf = config[c]
|
||||
if conf is None:
|
||||
continue
|
||||
base_path = None
|
||||
if "base_path" in conf:
|
||||
base_path = conf.pop("base_path")
|
||||
for x in conf:
|
||||
for y in conf[x].split("\n"):
|
||||
if len(y) == 0:
|
||||
continue
|
||||
full_path = y
|
||||
if base_path is not None:
|
||||
full_path = os.path.join(base_path, full_path)
|
||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||
folder_paths.add_model_folder_path(x, full_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.temp_directory:
|
||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||
@@ -222,11 +204,11 @@ if __name__ == "__main__":
|
||||
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config(extra_model_paths_config_path)
|
||||
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
load_extra_path_config(config_path)
|
||||
utils.extra_config.load_extra_path_config(config_path)
|
||||
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||
|
||||
@@ -247,21 +229,30 @@ if __name__ == "__main__":
|
||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||
|
||||
if args.input_directory:
|
||||
input_dir = os.path.abspath(args.input_directory)
|
||||
logging.info(f"Setting input directory to: {input_dir}")
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
|
||||
if args.user_directory:
|
||||
user_dir = os.path.abspath(args.user_directory)
|
||||
logging.info(f"Setting user directory to: {user_dir}")
|
||||
folder_paths.set_user_directory(user_dir)
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||
call_on_start = None
|
||||
if args.auto_launch:
|
||||
def startup_server(scheme, address, port):
|
||||
import webbrowser
|
||||
if os.name == 'nt' and address == '0.0.0.0':
|
||||
address = '127.0.0.1'
|
||||
if ':' in address:
|
||||
address = "[{}]".format(address)
|
||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||
call_on_start = startup_server
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#NOTE: This was an experiment and WILL BE REMOVED
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import os
|
||||
import traceback
|
||||
import logging
|
||||
from folder_paths import models_dir
|
||||
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||
from enum import Enum
|
||||
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadModelStatus():
|
||||
status: str
|
||||
@@ -29,7 +31,7 @@ class DownloadModelStatus():
|
||||
self.progress_percentage = progress_percentage
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": self.status,
|
||||
@@ -38,102 +40,112 @@ class DownloadModelStatus():
|
||||
"already_existed": self.already_existed
|
||||
}
|
||||
|
||||
|
||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_sub_directory: str,
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_directory: str,
|
||||
folder_path: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||
"""
|
||||
Download a model file from a given URL into the models directory.
|
||||
|
||||
Args:
|
||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||
model_name (str):
|
||||
model_name (str):
|
||||
The name of the model file to be downloaded. This will be the filename on disk.
|
||||
model_url (str):
|
||||
model_url (str):
|
||||
The URL from which to download the model.
|
||||
model_sub_directory (str):
|
||||
The subdirectory within the main models directory where the model
|
||||
model_directory (str):
|
||||
The subdirectory within the main models directory where the model
|
||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||
An asynchronous function to call with progress updates.
|
||||
folder_path (str);
|
||||
Path to which model folder should be used as the root.
|
||||
|
||||
Returns:
|
||||
DownloadModelStatus: The result of the download operation.
|
||||
"""
|
||||
if not validate_model_subdirectory(model_sub_directory):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model subdirectory",
|
||||
False
|
||||
)
|
||||
|
||||
if not validate_filename(model_name):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model name",
|
||||
"Invalid model name",
|
||||
False
|
||||
)
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||
if not model_directory in folder_names_and_paths:
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||
False
|
||||
)
|
||||
|
||||
if not folder_path in get_folder_paths(model_directory):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||
False
|
||||
)
|
||||
|
||||
file_path = create_model_path(model_name, folder_path)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||
if existing_file:
|
||||
return existing_file
|
||||
|
||||
try:
|
||||
logging.info(f"Downloading {model_name} from {model_url}")
|
||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
|
||||
response = await model_download_request(model_url)
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
logging.error(error_message)
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in downloading model: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
return await handle_download_error(e, model_name, progress_callback)
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||
os.makedirs(full_model_dir, exist_ok=True)
|
||||
file_path = os.path.join(full_model_dir, model_name)
|
||||
|
||||
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
file_path = os.path.join(folder_path, model_name)
|
||||
|
||||
# Ensure the resulting path is still within the base directory
|
||||
abs_file_path = os.path.abspath(file_path)
|
||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
||||
abs_base_dir = os.path.abspath(folder_path)
|
||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
relative_path = '/'.join([model_directory, model_name])
|
||||
return file_path, relative_path
|
||||
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||
) -> Optional[DownloadModelStatus]:
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
return status
|
||||
return None
|
||||
|
||||
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str,
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
interval: float = 1.0) -> DownloadModelStatus:
|
||||
try:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
nonlocal last_update_time
|
||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
last_update_time = time.time()
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
temp_file_path = file_path + '.tmp'
|
||||
with open(temp_file_path, 'wb') as f:
|
||||
chunk_iterator = response.content.iter_chunked(8192)
|
||||
while True:
|
||||
try:
|
||||
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
os.rename(temp_file_path, file_path)
|
||||
|
||||
await update_progress()
|
||||
|
||||
|
||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
|
||||
return status
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
return await handle_download_error(e, model_name, progress_callback)
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||
relative_path: str) -> DownloadModelStatus:
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||
) -> DownloadModelStatus:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(relative_path, status)
|
||||
await progress_callback(model_name, status)
|
||||
return status
|
||||
|
||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||
"""
|
||||
Validate that the model subdirectory is safe to install into.
|
||||
Must not contain relative paths, nested paths or special characters
|
||||
other than underscores and hyphens.
|
||||
|
||||
Args:
|
||||
model_subdirectory (str): The subdirectory for the specific model type.
|
||||
|
||||
Returns:
|
||||
bool: True if the subdirectory is safe, False otherwise.
|
||||
"""
|
||||
if len(model_subdirectory) > 50:
|
||||
return False
|
||||
|
||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||
return False
|
||||
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_filename(filename: str)-> bool:
|
||||
"""
|
||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||
|
||||
|
||||
Args:
|
||||
filename (str): The filename to validate
|
||||
|
||||
|
||||
54
nodes.py
54
nodes.py
@@ -511,10 +511,11 @@ class CheckpointLoader:
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
DEPRECATED = True
|
||||
|
||||
def load_checkpoint(self, config_name, ckpt_name):
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
class CheckpointLoaderSimple:
|
||||
@@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
|
||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||
|
||||
def load_checkpoint(self, ckpt_name):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out[:3]
|
||||
|
||||
@@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out
|
||||
|
||||
@@ -624,7 +625,7 @@ class LoraLoader:
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (model, clip)
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
lora = None
|
||||
if self.loaded_lora is not None:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
@@ -703,11 +704,11 @@ class VAELoader:
|
||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||
|
||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
|
||||
for k in enc:
|
||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||
|
||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
|
||||
for k in dec:
|
||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||
|
||||
@@ -738,7 +739,7 @@ class VAELoader:
|
||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
return (vae,)
|
||||
@@ -754,7 +755,7 @@ class ControlNetLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
||||
return (controlnet,)
|
||||
|
||||
@@ -770,7 +771,7 @@ class DiffControlNetLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, model, control_net_name):
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
||||
return (controlnet,)
|
||||
|
||||
@@ -786,6 +787,7 @@ class ControlNetApply:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
DEPRECATED = True
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
@@ -815,7 +817,10 @@ class ControlNetApplyAdvanced:
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
@@ -823,7 +828,7 @@ class ControlNetApplyAdvanced:
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||
if strength == 0:
|
||||
return (positive, negative)
|
||||
|
||||
@@ -840,7 +845,7 @@ class ControlNetApplyAdvanced:
|
||||
if prev_cnet in cnets:
|
||||
c_net = cnets[prev_cnet]
|
||||
else:
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||
c_net.set_previous_controlnet(prev_cnet)
|
||||
cnets[prev_cnet] = c_net
|
||||
|
||||
@@ -856,7 +861,7 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
@@ -867,10 +872,13 @@ class UNETLoader:
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
@@ -895,7 +903,7 @@ class CLIPLoader:
|
||||
else:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
@@ -912,8 +920,8 @@ class DualCLIPLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type):
|
||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||
if type == "sdxl":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == "sd3":
|
||||
@@ -935,7 +943,7 @@ class CLIPVisionLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
||||
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
|
||||
clip_vision = comfy.clip_vision.load(clip_path)
|
||||
return (clip_vision,)
|
||||
|
||||
@@ -965,7 +973,7 @@ class StyleModelLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_style_model(self, style_model_name):
|
||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
||||
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
|
||||
style_model = comfy.sd.load_style_model(style_model_path)
|
||||
return (style_model,)
|
||||
|
||||
@@ -1030,7 +1038,7 @@ class GLIGENLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_gligen(self, gligen_name):
|
||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
||||
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
|
||||
gligen = comfy.sd.load_gligen(gligen_path)
|
||||
return (gligen,)
|
||||
|
||||
@@ -1916,8 +1924,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||
"ControlNetApply": "Apply ControlNet",
|
||||
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
||||
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||
# Latent
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||
@@ -2101,6 +2109,8 @@ def init_builtin_extra_nodes():
|
||||
"nodes_controlnet.py",
|
||||
"nodes_hunyuan.py",
|
||||
"nodes_flux.py",
|
||||
"nodes_lora_extract.py",
|
||||
"nodes_torch_compile.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -38,6 +38,9 @@ def get_images(ws, prompt):
|
||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||
break #Execution is done
|
||||
else:
|
||||
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
|
||||
# bytesIO = BytesIO(out[8:])
|
||||
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
|
||||
continue #previews are binary data
|
||||
|
||||
history = get_history(prompt_id)[prompt_id]
|
||||
@@ -151,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||
images = get_images(ws, prompt)
|
||||
|
||||
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||
#Commented out code to display the output images:
|
||||
|
||||
# for node_id in images:
|
||||
|
||||
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||
images = get_images(ws, prompt)
|
||||
|
||||
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||
#Commented out code to display the output images:
|
||||
|
||||
# for node_id in images:
|
||||
|
||||
128
server.py
128
server.py
@@ -12,6 +12,8 @@ import json
|
||||
import glob
|
||||
import struct
|
||||
import ssl
|
||||
import socket
|
||||
import ipaddress
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from io import BytesIO
|
||||
@@ -80,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
def is_loopback(host):
|
||||
if host is None:
|
||||
return False
|
||||
try:
|
||||
if ipaddress.ip_address(host).is_loopback:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
loopback = False
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||
for family, _, _, _, sockaddr in r:
|
||||
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||
return loopback
|
||||
else:
|
||||
loopback = True
|
||||
except socket.gaierror:
|
||||
pass
|
||||
|
||||
return loopback
|
||||
|
||||
|
||||
def create_origin_only_middleware():
|
||||
@web.middleware
|
||||
async def origin_only_middleware(request: web.Request, handler):
|
||||
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||
#in that case the Host and Origin hostnames won't match
|
||||
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||
host = request.headers['Host']
|
||||
origin = request.headers['Origin']
|
||||
host_domain = host.lower()
|
||||
parsed = urllib.parse.urlparse(origin)
|
||||
origin_domain = parsed.netloc.lower()
|
||||
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||
|
||||
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||
loopback = is_loopback(host_domain_parsed.hostname)
|
||||
|
||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||
host_domain = host_domain_parsed.hostname
|
||||
if host_domain_parsed.port is None:
|
||||
origin_domain = parsed.hostname
|
||||
|
||||
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||
if host_domain != origin_domain:
|
||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||
return web.Response(status=403)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
response = web.Response()
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
return response
|
||||
|
||||
return origin_only_middleware
|
||||
|
||||
class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
@@ -99,6 +163,8 @@ class PromptServer():
|
||||
middlewares = [cache_control]
|
||||
if args.enable_cors_header:
|
||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||
else:
|
||||
middlewares.append(create_origin_only_middleware())
|
||||
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
@@ -155,6 +221,12 @@ class PromptServer():
|
||||
def get_embeddings(self):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@routes.get("/models")
|
||||
def list_model_types(request):
|
||||
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||
|
||||
return web.json_response(model_types)
|
||||
|
||||
@routes.get("/models/{folder}")
|
||||
async def get_models(request):
|
||||
@@ -418,12 +490,17 @@ class PromptServer():
|
||||
async def system_stats(request):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
cpu_device = comfy.model_management.torch.device("cpu")
|
||||
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
"os": os.name,
|
||||
"ram_total": ram_total,
|
||||
"ram_free": ram_free,
|
||||
"comfyui_version": get_comfyui_version(),
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
@@ -480,14 +557,15 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
try:
|
||||
out[x] = node_info(x)
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||
logging.error(traceback.format_exc())
|
||||
return web.json_response(out)
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
try:
|
||||
out[x] = node_info(x)
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||
logging.error(traceback.format_exc())
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/object_info/{node_class}")
|
||||
async def get_object_info_node(request):
|
||||
@@ -601,6 +679,7 @@ class PromptServer():
|
||||
|
||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||
# NOTE: This was an experiment and WILL BE REMOVED
|
||||
@routes.post("/internal/models/download")
|
||||
async def download_handler(request):
|
||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||
@@ -611,10 +690,11 @@ class PromptServer():
|
||||
data = await request.json()
|
||||
url = data.get('url')
|
||||
model_directory = data.get('model_directory')
|
||||
folder_path = data.get('folder_path')
|
||||
model_filename = data.get('model_filename')
|
||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||
|
||||
if not url or not model_directory or not model_filename:
|
||||
if not url or not model_directory or not model_filename or not folder_path:
|
||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||
|
||||
session = self.client_session
|
||||
@@ -622,7 +702,7 @@ class PromptServer():
|
||||
logging.error("Client session is not initialized")
|
||||
return web.Response(status=500)
|
||||
|
||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||
await task
|
||||
|
||||
return web.json_response(task.result().to_dict())
|
||||
@@ -739,6 +819,9 @@ class PromptServer():
|
||||
await self.send(*msg)
|
||||
|
||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||
|
||||
async def start_multi_address(self, addresses, call_on_start=None):
|
||||
runner = web.AppRunner(self.app, access_log=None)
|
||||
await runner.setup()
|
||||
ssl_ctx = None
|
||||
@@ -749,17 +832,26 @@ class PromptServer():
|
||||
keyfile=args.tls_keyfile)
|
||||
scheme = "https"
|
||||
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
logging.info("Starting server\n")
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
|
||||
self.address = address
|
||||
self.port = port
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
self.port = port
|
||||
|
||||
if ':' in address:
|
||||
address_print = "[{}]".format(address)
|
||||
else:
|
||||
address_print = address
|
||||
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
||||
if call_on_start is not None:
|
||||
call_on_start(scheme, address, port)
|
||||
call_on_start(scheme, self.address, self.port)
|
||||
|
||||
def add_on_prompt_handler(self, handler):
|
||||
self.on_prompt_handlers.append(handler)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## Install test dependencies
|
||||
|
||||
`pip install -r tests-units/requirements.txt`
|
||||
`pip install -r tests-unit/requirements.txt`
|
||||
|
||||
## Run tests
|
||||
`pytest tests-units/`
|
||||
`pytest tests-unit/`
|
||||
|
||||
66
tests-unit/comfy_test/folder_path_test.py
Normal file
66
tests-unit/comfy_test/folder_path_test.py
Normal file
@@ -0,0 +1,66 @@
|
||||
### 🗻 This file is created through the spirit of Mount Fuji at its peak
|
||||
# TODO(yoland): clean up this after I get back down
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import folder_paths
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
yield tmpdirname
|
||||
|
||||
|
||||
def test_get_directory_by_type():
|
||||
test_dir = "/test/dir"
|
||||
folder_paths.set_output_directory(test_dir)
|
||||
assert folder_paths.get_directory_by_type("output") == test_dir
|
||||
assert folder_paths.get_directory_by_type("invalid") is None
|
||||
|
||||
def test_annotated_filepath():
|
||||
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
|
||||
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
|
||||
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
|
||||
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
|
||||
|
||||
def test_get_annotated_filepath():
|
||||
default_dir = "/default/dir"
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||
|
||||
def test_add_model_folder_path():
|
||||
folder_paths.add_model_folder_path("test_folder", "/test/path")
|
||||
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||
|
||||
def test_recursive_search(temp_dir):
|
||||
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
|
||||
|
||||
files, dirs = folder_paths.recursive_search(temp_dir)
|
||||
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||
assert len(dirs) == 2 # temp_dir and subdir
|
||||
|
||||
def test_filter_files_extensions():
|
||||
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
|
||||
assert folder_paths.filter_files_extensions(files, []) == files
|
||||
|
||||
@patch("folder_paths.recursive_search")
|
||||
@patch("folder_paths.folder_names_and_paths")
|
||||
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
|
||||
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||
|
||||
def test_get_save_image_path(temp_dir):
|
||||
with patch("folder_paths.output_directory", temp_dir):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||
assert os.path.samefile(full_output_folder, temp_dir)
|
||||
assert filename == "test"
|
||||
assert counter == 1
|
||||
assert subfolder == ""
|
||||
assert filename_prefix == "test"
|
||||
0
tests-unit/folder_paths_test/__init__.py
Normal file
0
tests-unit/folder_paths_test/__init__.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
52
tests-unit/folder_paths_test/filter_by_content_types_test.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from folder_paths import filter_files_content_types
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def file_extensions():
|
||||
return {
|
||||
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_dir(file_extensions):
|
||||
with tempfile.TemporaryDirectory() as directory:
|
||||
for content_type, extensions in file_extensions.items():
|
||||
for extension in extensions:
|
||||
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
|
||||
f.write(f"Sample {content_type} file in {extension} format")
|
||||
yield directory
|
||||
|
||||
|
||||
def test_categorizes_all_correctly(mock_dir, file_extensions):
|
||||
files = os.listdir(mock_dir)
|
||||
for content_type, extensions in file_extensions.items():
|
||||
filtered_files = filter_files_content_types(files, [content_type])
|
||||
for extension in extensions:
|
||||
assert f"sample_{content_type}.{extension}" in filtered_files
|
||||
|
||||
|
||||
def test_categorizes_all_uniquely(mock_dir, file_extensions):
|
||||
files = os.listdir(mock_dir)
|
||||
for content_type, extensions in file_extensions.items():
|
||||
filtered_files = filter_files_content_types(files, [content_type])
|
||||
assert len(filtered_files) == len(extensions)
|
||||
|
||||
|
||||
def test_handles_bad_extensions():
|
||||
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
|
||||
def test_handles_no_extension():
|
||||
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
|
||||
def test_handles_no_files():
|
||||
files = []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
@@ -1,10 +1,17 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import aiohttp
|
||||
from aiohttp import ClientResponse
|
||||
import itertools
|
||||
import os
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
import folder_paths
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
yield tmpdirname
|
||||
|
||||
class AsyncIteratorMock:
|
||||
"""
|
||||
@@ -42,7 +49,7 @@ class ContentMock:
|
||||
return AsyncIteratorMock(self.chunks)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
async def test_download_model_success(temp_dir):
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 200
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
@@ -53,15 +60,13 @@ async def test_download_model_success():
|
||||
mock_make_request = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
# Mock file operations
|
||||
mock_open = MagicMock()
|
||||
mock_file = MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
time_values = itertools.count(0, 0.1)
|
||||
|
||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||
|
||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('builtins.open', mock_open), \
|
||||
patch('folder_paths.folder_names_and_paths', fake_paths), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
|
||||
result = await download_model(
|
||||
@@ -69,6 +74,7 @@ async def test_download_model_success():
|
||||
'model.sft',
|
||||
'http://example.com/model.sft',
|
||||
'checkpoints',
|
||||
temp_dir,
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
@@ -83,44 +89,48 @@ async def test_download_model_success():
|
||||
|
||||
# Check initial call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.sft',
|
||||
'model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||
)
|
||||
|
||||
# Check final call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.sft',
|
||||
'model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||
)
|
||||
|
||||
# Verify file writing
|
||||
mock_file.write.assert_any_call(b'a' * 500)
|
||||
mock_file.write.assert_any_call(b'b' * 300)
|
||||
mock_file.write.assert_any_call(b'c' * 200)
|
||||
mock_file_path = os.path.join(temp_dir, 'model.sft')
|
||||
assert os.path.exists(mock_file_path)
|
||||
with open(mock_file_path, 'rb') as mock_file:
|
||||
assert mock_file.read() == b''.join(chunks)
|
||||
os.remove(mock_file_path)
|
||||
|
||||
# Verify request was made
|
||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_url_request_failure():
|
||||
async def test_download_model_url_request_failure(temp_dir):
|
||||
# Mock dependencies
|
||||
mock_response = AsyncMock(spec=ClientResponse)
|
||||
mock_response.status = 404 # Simulate a "Not Found" error
|
||||
mock_get = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
|
||||
|
||||
# Mock the create_model_path function
|
||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||
# Call the function
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
'model.safetensors',
|
||||
'http://example.com/model.safetensors',
|
||||
'mock_directory',
|
||||
mock_progress_callback
|
||||
)
|
||||
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('folder_paths.folder_names_and_paths', fake_paths):
|
||||
# Call the function
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
'model.safetensors',
|
||||
'http://example.com/model.safetensors',
|
||||
'checkpoints',
|
||||
temp_dir,
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
|
||||
|
||||
# Check that progress_callback was called with the correct arguments
|
||||
mock_progress_callback.assert_any_call(
|
||||
'mock_directory/model.safetensors',
|
||||
'model.safetensors',
|
||||
DownloadModelStatus(
|
||||
status=DownloadStatusType.PENDING,
|
||||
progress_percentage=0,
|
||||
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
|
||||
)
|
||||
)
|
||||
mock_progress_callback.assert_called_with(
|
||||
'mock_directory/model.safetensors',
|
||||
'model.safetensors',
|
||||
DownloadModelStatus(
|
||||
status=DownloadStatusType.ERROR,
|
||||
progress_percentage=0,
|
||||
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_invalid_model_subdirectory():
|
||||
|
||||
mock_make_request = AsyncMock()
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.sft',
|
||||
'http://example.com/model.sft',
|
||||
'../bad_path',
|
||||
'../bad_path',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
assert result.message == 'Invalid model subdirectory'
|
||||
assert result.message.startswith('Invalid or unrecognized model directory')
|
||||
assert result.status == 'error'
|
||||
assert result.already_existed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_invalid_folder_path():
|
||||
mock_make_request = AsyncMock()
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.sft',
|
||||
'http://example.com/model.sft',
|
||||
'checkpoints',
|
||||
'invalid_path',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
assert result.message.startswith("Invalid folder path")
|
||||
assert result.status == 'error'
|
||||
assert result.already_existed is False
|
||||
|
||||
# For create_model_path function
|
||||
def test_create_model_path(tmp_path, monkeypatch):
|
||||
mock_models_dir = tmp_path / "models"
|
||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||
|
||||
model_name = "test_model.sft"
|
||||
model_directory = "test_dir"
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||
|
||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||
assert relative_path == f"{model_directory}/{model_name}"
|
||||
model_name = "model.safetensors"
|
||||
folder_path = os.path.join(tmp_path, "mock_dir")
|
||||
|
||||
file_path = create_model_path(model_name, folder_path)
|
||||
|
||||
assert file_path == os.path.join(folder_path, "model.safetensors")
|
||||
assert os.path.exists(os.path.dirname(file_path))
|
||||
|
||||
with pytest.raises(Exception, match="Invalid model directory"):
|
||||
create_model_path("../path_traversal.safetensors", folder_path)
|
||||
|
||||
with pytest.raises(Exception, match="Invalid model directory"):
|
||||
create_model_path("/etc/some_root_path", folder_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||
file_path = tmp_path / "existing_model.sft"
|
||||
file_path.touch() # Create an empty file
|
||||
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||
|
||||
|
||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == "completed"
|
||||
assert result.message == "existing_model.sft already exists"
|
||||
assert result.already_existed is True
|
||||
|
||||
|
||||
mock_callback.assert_called_once_with(
|
||||
"test/existing_model.sft",
|
||||
"existing_model.sft",
|
||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||
file_path = tmp_path / "non_existing_model.sft"
|
||||
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||
|
||||
|
||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
|
||||
|
||||
assert result is None
|
||||
mock_callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_no_content_length():
|
||||
async def test_track_download_progress_no_content_length(temp_dir):
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {} # No Content-Length header
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
||||
chunks = [b'a' * 500, b'b' * 500]
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch('builtins.open', mock_open):
|
||||
result = await track_download_progress(
|
||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||
mock_callback, 'models/model.sft', interval=0.1
|
||||
)
|
||||
full_path = os.path.join(temp_dir, 'model.sft')
|
||||
|
||||
result = await track_download_progress(
|
||||
mock_response, full_path, 'model.sft',
|
||||
mock_callback, interval=0.1
|
||||
)
|
||||
|
||||
assert result.status == "completed"
|
||||
|
||||
assert os.path.exists(full_path)
|
||||
with open(full_path, 'rb') as f:
|
||||
assert f.read() == b''.join(chunks)
|
||||
os.remove(full_path)
|
||||
|
||||
# Check that progress was reported even without knowing the total size
|
||||
mock_callback.assert_any_call(
|
||||
'models/model.sft',
|
||||
'model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_interval():
|
||||
async def test_track_download_progress_interval(temp_dir):
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
||||
chunks = [b'a' * 100] * 10
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
|
||||
mock_time = MagicMock()
|
||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||
|
||||
with patch('builtins.open', mock_open), \
|
||||
patch('time.time', mock_time):
|
||||
await track_download_progress(
|
||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||
mock_callback, 'models/model.sft', interval=1.0
|
||||
)
|
||||
full_path = os.path.join(temp_dir, 'model.sft')
|
||||
|
||||
# Print out the actual call count and the arguments of each call for debugging
|
||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
||||
for i, call in enumerate(mock_callback.call_args_list):
|
||||
args, kwargs = call
|
||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
||||
with patch('time.time', mock_time):
|
||||
await track_download_progress(
|
||||
mock_response, full_path, 'model.sft',
|
||||
mock_callback, interval=1.0
|
||||
)
|
||||
|
||||
assert os.path.exists(full_path)
|
||||
with open(full_path, 'rb') as f:
|
||||
assert f.read() == b''.join(chunks)
|
||||
os.remove(full_path)
|
||||
|
||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
|
||||
assert last_call[0][1].status == "completed"
|
||||
assert last_call[0][1].progress_percentage == 100
|
||||
|
||||
def test_valid_subdirectory():
|
||||
assert validate_model_subdirectory("valid-model123") is True
|
||||
|
||||
def test_subdirectory_too_long():
|
||||
assert validate_model_subdirectory("a" * 51) is False
|
||||
|
||||
def test_subdirectory_with_double_dots():
|
||||
assert validate_model_subdirectory("model/../unsafe") is False
|
||||
|
||||
def test_subdirectory_with_slash():
|
||||
assert validate_model_subdirectory("model/unsafe") is False
|
||||
|
||||
def test_subdirectory_with_special_characters():
|
||||
assert validate_model_subdirectory("model@unsafe") is False
|
||||
|
||||
def test_subdirectory_with_underscore_and_dash():
|
||||
assert validate_model_subdirectory("valid_model-name") is True
|
||||
|
||||
def test_empty_subdirectory():
|
||||
assert validate_model_subdirectory("") is False
|
||||
|
||||
@pytest.mark.parametrize("filename, expected", [
|
||||
("valid_model.safetensors", True),
|
||||
("valid_model.sft", True),
|
||||
|
||||
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
120
tests-unit/prompt_server_test/user_manager_test.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
import os
|
||||
from aiohttp import web
|
||||
from app.user_manager import UserManager
|
||||
from unittest.mock import patch
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
)
|
||||
return um
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(user_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
user_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == ["file1.txt"]
|
||||
|
||||
|
||||
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||
assert resp.status == 200
|
||||
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
|
||||
|
||||
|
||||
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&full_info=true")
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert result[0]["path"] == "file1.txt"
|
||||
assert "size" in result[0]
|
||||
assert "modified" in result[0]
|
||||
|
||||
|
||||
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == [
|
||||
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||
]
|
||||
|
||||
|
||||
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=")
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||
os_sep = "\\"
|
||||
with patch("os.sep", os_sep):
|
||||
with patch("os.path.sep", os_sep):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert "/" in result[0] # Ensure forward slash is used
|
||||
assert "\\" not in result[0] # Ensure backslash is not present
|
||||
assert result[0] == "subdir/file1.txt"
|
||||
|
||||
# Test with full_info
|
||||
resp = await client.get(
|
||||
"/userdata?dir=test_dir&recurse=true&full_info=true"
|
||||
)
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||
assert result[0]["path"] == "subdir/file1.txt"
|
||||
126
tests-unit/utils/extra_config_test.py
Normal file
126
tests-unit/utils/extra_config_test.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import pytest
|
||||
import yaml
|
||||
import os
|
||||
from unittest.mock import Mock, patch, mock_open
|
||||
|
||||
from utils.extra_config import load_extra_path_config
|
||||
import folder_paths
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content():
|
||||
return {
|
||||
'test_config': {
|
||||
'base_path': '~/App/',
|
||||
'checkpoints': 'subfolder1',
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanded_home():
|
||||
return '/home/user'
|
||||
|
||||
@pytest.fixture
|
||||
def yaml_config_with_appdata():
|
||||
return """
|
||||
test_config:
|
||||
base_path: '%APPDATA%/ComfyUI'
|
||||
checkpoints: 'models/checkpoints'
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||
return yaml.safe_load(yaml_config_with_appdata)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expandvars_appdata():
|
||||
mock = Mock()
|
||||
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_add_model_folder_path():
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanduser(mock_expanded_home):
|
||||
def _expanduser(path):
|
||||
if path.startswith('~/'):
|
||||
return os.path.join(mock_expanded_home, path[2:])
|
||||
return path
|
||||
return _expanduser
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_safe_load(mock_yaml_content):
|
||||
return Mock(return_value=mock_yaml_content)
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||
def test_load_extra_model_paths_expands_userpath(
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expanduser,
|
||||
mock_yaml_safe_load,
|
||||
mock_expanded_home
|
||||
):
|
||||
# Attach mocks used by load_extra_path_config
|
||||
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
|
||||
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
|
||||
|
||||
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||
load_extra_path_config(dummy_yaml_file_name)
|
||||
|
||||
expected_calls = [
|
||||
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
# Check if add_model_folder_path was called with the correct arguments
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args[0] == expected_call[0]
|
||||
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
|
||||
assert actual_call.args[2] == expected_call[2]
|
||||
|
||||
# Check if yaml.safe_load was called
|
||||
mock_yaml_safe_load.assert_called_once()
|
||||
|
||||
# Check if open was called with the correct file path
|
||||
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
def test_load_extra_model_paths_expands_appdata(
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expandvars_appdata,
|
||||
yaml_config_with_appdata,
|
||||
mock_yaml_content_appdata
|
||||
):
|
||||
# Set the mock_file to return yaml with appdata as a variable
|
||||
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||
|
||||
# Attach mocks
|
||||
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
|
||||
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
|
||||
|
||||
# Mock expanduser to do nothing (since we're not testing it here)
|
||||
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
|
||||
|
||||
dummy_yaml_file_name = 'dummy_path.yaml'
|
||||
load_extra_path_config(dummy_yaml_file_name)
|
||||
|
||||
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
|
||||
expected_calls = [
|
||||
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
# Check the base path variable was expanded
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args == expected_call
|
||||
|
||||
# Verify that expandvars was called
|
||||
assert mock_expandvars_appdata.called
|
||||
@@ -496,3 +496,29 @@ class TestExecution:
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||
# only that one entry in the list is blocked.
|
||||
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
|
||||
int1 = g.node("StubInt", value=1)
|
||||
int2 = g.node("StubInt", value=2)
|
||||
int3 = g.node("StubInt", value=3)
|
||||
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||
output = g.node("PreviewImage", images=list_output.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
assert result.did_run(output), "The execution should have run"
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||
|
||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
28
utils/extra_config.py
Normal file
28
utils/extra_config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
import yaml
|
||||
import folder_paths
|
||||
import logging
|
||||
|
||||
def load_extra_path_config(yaml_path):
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
for c in config:
|
||||
conf = config[c]
|
||||
if conf is None:
|
||||
continue
|
||||
base_path = None
|
||||
if "base_path" in conf:
|
||||
base_path = conf.pop("base_path")
|
||||
base_path = os.path.expandvars(os.path.expanduser(base_path))
|
||||
is_default = False
|
||||
if "is_default" in conf:
|
||||
is_default = conf.pop("is_default")
|
||||
for x in conf:
|
||||
for y in conf[x].split("\n"):
|
||||
if len(y) == 0:
|
||||
continue
|
||||
full_path = y
|
||||
if base_path is not None:
|
||||
full_path = os.path.join(base_path, full_path)
|
||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||
1
web/assets/CREDIT.txt
generated
vendored
Normal file
1
web/assets/CREDIT.txt
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.
|
||||
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
792
web/assets/GraphView-BGt8GmeB.css
generated
vendored
Normal file
@@ -0,0 +1,792 @@
|
||||
|
||||
.editable-text[data-v-54da6fc9] {
|
||||
display: inline;
|
||||
}
|
||||
.editable-text input[data-v-54da6fc9] {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
|
||||
z-index: 9999;
|
||||
padding: 0.25rem;
|
||||
}
|
||||
[data-v-fc3f26e3] .editable-text {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
[data-v-fc3f26e3] .editable-text input {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
/* Override the default font size */
|
||||
font-size: inherit;
|
||||
}
|
||||
|
||||
.side-bar-button-icon {
|
||||
font-size: var(--sidebar-icon-size) !important;
|
||||
}
|
||||
.side-bar-button-selected .side-bar-button-icon {
|
||||
font-size: var(--sidebar-icon-size) !important;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.side-bar-button[data-v-caa3ee9c] {
|
||||
width: var(--sidebar-width);
|
||||
height: var(--sidebar-width);
|
||||
border-radius: 0;
|
||||
}
|
||||
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||
border-left: 4px solid var(--p-button-text-primary-color);
|
||||
}
|
||||
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
|
||||
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
|
||||
border-right: 4px solid var(--p-button-text-primary-color);
|
||||
}
|
||||
|
||||
:root {
|
||||
--sidebar-width: 64px;
|
||||
--sidebar-icon-size: 1.5rem;
|
||||
}
|
||||
:root .small-sidebar {
|
||||
--sidebar-width: 40px;
|
||||
--sidebar-icon-size: 1rem;
|
||||
}
|
||||
|
||||
.side-tool-bar-container[data-v-4da64512] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
|
||||
pointer-events: auto;
|
||||
|
||||
width: var(--sidebar-width);
|
||||
height: 100%;
|
||||
|
||||
background-color: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
}
|
||||
.side-tool-bar-end[data-v-4da64512] {
|
||||
align-self: flex-end;
|
||||
margin-top: auto;
|
||||
}
|
||||
.sidebar-content-container[data-v-4da64512] {
|
||||
height: 100%;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.p-splitter-gutter {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.gutter-hidden {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
.side-bar-panel[data-v-b9df3042] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.splitter-overlay[data-v-b9df3042] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
background-color: transparent;
|
||||
pointer-events: none;
|
||||
/* Set it the same as the ComfyUI menu */
|
||||
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
|
||||
999 should be sufficient to make sure splitter overlays on node's DOM
|
||||
widgets */
|
||||
z-index: 999;
|
||||
border: none;
|
||||
}
|
||||
|
||||
._content[data-v-e7b35fd9] {
|
||||
|
||||
display: flex;
|
||||
|
||||
flex-direction: column
|
||||
}
|
||||
._content[data-v-e7b35fd9] > :not([hidden]) ~ :not([hidden]) {
|
||||
|
||||
--tw-space-y-reverse: 0;
|
||||
|
||||
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
|
||||
|
||||
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse))
|
||||
}
|
||||
._footer[data-v-e7b35fd9] {
|
||||
|
||||
display: flex;
|
||||
|
||||
flex-direction: column;
|
||||
|
||||
align-items: flex-end;
|
||||
|
||||
padding-top: 1rem
|
||||
}
|
||||
|
||||
[data-v-37f672ab] .highlight {
|
||||
background-color: var(--p-primary-color);
|
||||
color: var(--p-primary-contrast-color);
|
||||
font-weight: bold;
|
||||
border-radius: 0.25rem;
|
||||
padding: 0rem 0.125rem;
|
||||
margin: -0.125rem 0.125rem;
|
||||
}
|
||||
|
||||
.slot_row[data-v-ff07c900] {
|
||||
padding: 2px;
|
||||
}
|
||||
|
||||
/* Original N-Sidebar styles */
|
||||
._sb_dot[data-v-ff07c900] {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background-color: grey;
|
||||
}
|
||||
.node_header[data-v-ff07c900] {
|
||||
line-height: 1;
|
||||
padding: 8px 13px 7px;
|
||||
margin-bottom: 5px;
|
||||
font-size: 15px;
|
||||
text-wrap: nowrap;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.headdot[data-v-ff07c900] {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
float: inline-start;
|
||||
margin-right: 8px;
|
||||
}
|
||||
.IMAGE[data-v-ff07c900] {
|
||||
background-color: #64b5f6;
|
||||
}
|
||||
.VAE[data-v-ff07c900] {
|
||||
background-color: #ff6e6e;
|
||||
}
|
||||
.LATENT[data-v-ff07c900] {
|
||||
background-color: #ff9cf9;
|
||||
}
|
||||
.MASK[data-v-ff07c900] {
|
||||
background-color: #81c784;
|
||||
}
|
||||
.CONDITIONING[data-v-ff07c900] {
|
||||
background-color: #ffa931;
|
||||
}
|
||||
.CLIP[data-v-ff07c900] {
|
||||
background-color: #ffd500;
|
||||
}
|
||||
.MODEL[data-v-ff07c900] {
|
||||
background-color: #b39ddb;
|
||||
}
|
||||
.CONTROL_NET[data-v-ff07c900] {
|
||||
background-color: #a5d6a7;
|
||||
}
|
||||
._sb_node_preview[data-v-ff07c900] {
|
||||
background-color: var(--comfy-menu-bg);
|
||||
font-family: 'Open Sans', sans-serif;
|
||||
font-size: small;
|
||||
color: var(--descrip-text);
|
||||
border: 1px solid var(--descrip-text);
|
||||
min-width: 300px;
|
||||
width: -moz-min-content;
|
||||
width: min-content;
|
||||
height: -moz-fit-content;
|
||||
height: fit-content;
|
||||
z-index: 9999;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
font-size: 12px;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
._sb_node_preview ._sb_description[data-v-ff07c900] {
|
||||
margin: 10px;
|
||||
padding: 6px;
|
||||
background: var(--border-color);
|
||||
border-radius: 5px;
|
||||
font-style: italic;
|
||||
font-weight: 500;
|
||||
font-size: 0.9rem;
|
||||
word-break: break-word;
|
||||
}
|
||||
._sb_table[data-v-ff07c900] {
|
||||
display: grid;
|
||||
|
||||
grid-column-gap: 10px;
|
||||
/* Spazio tra le colonne */
|
||||
width: 100%;
|
||||
/* Imposta la larghezza della tabella al 100% del contenitore */
|
||||
}
|
||||
._sb_row[data-v-ff07c900] {
|
||||
display: grid;
|
||||
grid-template-columns: 10px 1fr 1fr 1fr 10px;
|
||||
grid-column-gap: 10px;
|
||||
align-items: center;
|
||||
padding-left: 9px;
|
||||
padding-right: 9px;
|
||||
}
|
||||
._sb_row_string[data-v-ff07c900] {
|
||||
grid-template-columns: 10px 1fr 1fr 10fr 1fr;
|
||||
}
|
||||
._sb_col[data-v-ff07c900] {
|
||||
border: 0px solid #000;
|
||||
display: flex;
|
||||
align-items: flex-end;
|
||||
flex-direction: row-reverse;
|
||||
flex-wrap: nowrap;
|
||||
align-content: flex-start;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
._sb_inherit[data-v-ff07c900] {
|
||||
display: inherit;
|
||||
}
|
||||
._long_field[data-v-ff07c900] {
|
||||
background: var(--bg-color);
|
||||
border: 2px solid var(--border-color);
|
||||
margin: 5px 5px 0 5px;
|
||||
border-radius: 10px;
|
||||
line-height: 1.7;
|
||||
text-wrap: nowrap;
|
||||
}
|
||||
._sb_arrow[data-v-ff07c900] {
|
||||
color: var(--fg-color);
|
||||
}
|
||||
._sb_preview_badge[data-v-ff07c900] {
|
||||
text-align: center;
|
||||
background: var(--comfy-input-bg);
|
||||
font-weight: bold;
|
||||
color: var(--error-text);
|
||||
}
|
||||
|
||||
.comfy-vue-node-search-container[data-v-2d409367] {
|
||||
display: flex;
|
||||
width: 100%;
|
||||
min-width: 26rem;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.comfy-vue-node-search-container[data-v-2d409367] * {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.comfy-vue-node-preview-container[data-v-2d409367] {
|
||||
position: absolute;
|
||||
left: -350px;
|
||||
top: 50px;
|
||||
}
|
||||
.comfy-vue-node-search-box[data-v-2d409367] {
|
||||
z-index: 10;
|
||||
flex-grow: 1;
|
||||
}
|
||||
._filter-button[data-v-2d409367] {
|
||||
z-index: 10;
|
||||
}
|
||||
._dialog[data-v-2d409367] {
|
||||
min-width: 26rem;
|
||||
}
|
||||
|
||||
.invisible-dialog-root {
|
||||
width: 60%;
|
||||
min-width: 24rem;
|
||||
max-width: 48rem;
|
||||
border: 0 !important;
|
||||
background-color: transparent !important;
|
||||
margin-top: 25vh;
|
||||
margin-left: 400px;
|
||||
}
|
||||
@media all and (max-width: 768px) {
|
||||
.invisible-dialog-root {
|
||||
margin-left: 0px;
|
||||
}
|
||||
}
|
||||
.node-search-box-dialog-mask {
|
||||
align-items: flex-start !important;
|
||||
}
|
||||
|
||||
.node-tooltip[data-v-0a4402f9] {
|
||||
background: var(--comfy-input-bg);
|
||||
border-radius: 5px;
|
||||
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||
color: var(--input-text);
|
||||
font-family: sans-serif;
|
||||
left: 0;
|
||||
max-width: 30vw;
|
||||
padding: 4px 8px;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
transform: translate(5px, calc(-100% - 5px));
|
||||
white-space: pre-wrap;
|
||||
z-index: 99999;
|
||||
}
|
||||
|
||||
.p-buttongroup-vertical[data-v-ce8bd6ac] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-radius: var(--p-button-border-radius);
|
||||
overflow: hidden;
|
||||
border: 1px solid var(--p-panel-border-color);
|
||||
}
|
||||
.p-buttongroup-vertical .p-button[data-v-ce8bd6ac] {
|
||||
margin: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.comfy-image-wrap[data-v-9bc23daf] {
|
||||
display: contents;
|
||||
}
|
||||
.comfy-image-blur[data-v-9bc23daf] {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
-o-object-fit: cover;
|
||||
object-fit: cover;
|
||||
}
|
||||
.comfy-image-main[data-v-9bc23daf] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
-o-object-fit: cover;
|
||||
object-fit: cover;
|
||||
-o-object-position: center;
|
||||
object-position: center;
|
||||
z-index: 1;
|
||||
}
|
||||
.contain .comfy-image-wrap[data-v-9bc23daf] {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
.contain .comfy-image-main[data-v-9bc23daf] {
|
||||
-o-object-fit: contain;
|
||||
object-fit: contain;
|
||||
-webkit-backdrop-filter: blur(10px);
|
||||
backdrop-filter: blur(10px);
|
||||
position: absolute;
|
||||
}
|
||||
.broken-image-placeholder[data-v-9bc23daf] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 2rem;
|
||||
}
|
||||
.broken-image-placeholder i[data-v-9bc23daf] {
|
||||
font-size: 3rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.result-container[data-v-d9c060ae] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
aspect-ratio: 1 / 1;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.image-preview-mask[data-v-d9c060ae] {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s ease;
|
||||
z-index: 1;
|
||||
}
|
||||
.result-container:hover .image-preview-mask[data-v-d9c060ae] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.task-result-preview[data-v-d4c8a1fe] {
|
||||
aspect-ratio: 1 / 1;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
.task-result-preview i[data-v-d4c8a1fe],
|
||||
.task-result-preview span[data-v-d4c8a1fe] {
|
||||
font-size: 2rem;
|
||||
}
|
||||
.task-item[data-v-d4c8a1fe] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
.task-item-details[data-v-d4c8a1fe] {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
padding: 0.6rem;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
z-index: 1;
|
||||
}
|
||||
.task-node-link[data-v-d4c8a1fe] {
|
||||
padding: 2px;
|
||||
}
|
||||
|
||||
/* In dark mode, transparent background color for tags is not ideal for tags that
|
||||
are floating on top of images. */
|
||||
.tag-wrapper[data-v-d4c8a1fe] {
|
||||
background-color: var(--p-primary-contrast-color);
|
||||
border-radius: 6px;
|
||||
display: inline-flex;
|
||||
}
|
||||
.node-name-tag[data-v-d4c8a1fe] {
|
||||
word-break: break-all;
|
||||
}
|
||||
.status-tag-group[data-v-d4c8a1fe] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.progress-preview-img[data-v-d4c8a1fe] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
-o-object-fit: cover;
|
||||
object-fit: cover;
|
||||
-o-object-position: center;
|
||||
object-position: center;
|
||||
}
|
||||
|
||||
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
|
||||
cannot use scoped style here. */
|
||||
img.galleria-image {
|
||||
max-width: 100vw;
|
||||
max-height: 100vh;
|
||||
-o-object-fit: contain;
|
||||
object-fit: contain;
|
||||
}
|
||||
.p-galleria-close-button {
|
||||
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.comfy-vue-side-bar-container[data-v-1b0a8fe3] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
.comfy-vue-side-bar-header[data-v-1b0a8fe3] {
|
||||
flex-shrink: 0;
|
||||
border-left: none;
|
||||
border-right: none;
|
||||
border-top: none;
|
||||
border-radius: 0;
|
||||
padding: 0.25rem 1rem;
|
||||
min-height: 2.5rem;
|
||||
}
|
||||
.comfy-vue-side-bar-header-span[data-v-1b0a8fe3] {
|
||||
font-size: small;
|
||||
}
|
||||
.comfy-vue-side-bar-body[data-v-1b0a8fe3] {
|
||||
flex-grow: 1;
|
||||
overflow: auto;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: transparent transparent;
|
||||
}
|
||||
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar {
|
||||
width: 1px;
|
||||
}
|
||||
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar-thumb {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.scroll-container[data-v-08fa89b1] {
|
||||
height: 100%;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.queue-grid[data-v-08fa89b1] {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
|
||||
padding: 0.5rem;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.tree-node[data-v-633e27ab] {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
.leaf-count-badge[data-v-633e27ab] {
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
.node-content[data-v-633e27ab] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
flex-grow: 1;
|
||||
}
|
||||
.leaf-label[data-v-633e27ab] {
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
[data-v-633e27ab] .editable-text span {
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
[data-v-bd7bae90] .tree-explorer-node-label {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-left: var(--p-tree-node-gap);
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
/*
|
||||
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
|
||||
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
|
||||
* we can create a visual indicator for the drop target without affecting the layout of other elements.
|
||||
*/
|
||||
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder) {
|
||||
position: relative;
|
||||
}
|
||||
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder.can-drop)::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
border: 1px solid var(--p-content-color);
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.node-lib-node-container[data-v-90dfee08] {
|
||||
height: 100%;
|
||||
width: 100%
|
||||
}
|
||||
|
||||
.p-selectbutton .p-button[data-v-91077f2a] {
|
||||
padding: 0.5rem;
|
||||
}
|
||||
.p-selectbutton .p-button .pi[data-v-91077f2a] {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
.field[data-v-91077f2a] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.color-picker-container[data-v-91077f2a] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.node-lib-filter-popup {
|
||||
margin-left: -13px;
|
||||
}
|
||||
|
||||
[data-v-f6a7371a] .comfy-vue-side-bar-body {
|
||||
background: var(--p-tree-background);
|
||||
}
|
||||
[data-v-f6a7371a] .node-lib-bookmark-tree-explorer {
|
||||
padding-bottom: 2px;
|
||||
}
|
||||
[data-v-f6a7371a] .p-divider {
|
||||
margin: var(--comfy-tree-explorer-item-padding) 0px;
|
||||
}
|
||||
|
||||
.model_preview[data-v-32e6c4d9] {
|
||||
background-color: var(--comfy-menu-bg);
|
||||
font-family: 'Open Sans', sans-serif;
|
||||
color: var(--descrip-text);
|
||||
border: 1px solid var(--descrip-text);
|
||||
min-width: 300px;
|
||||
max-width: 500px;
|
||||
width: -moz-fit-content;
|
||||
width: fit-content;
|
||||
height: -moz-fit-content;
|
||||
height: fit-content;
|
||||
z-index: 9999;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
font-size: 12px;
|
||||
padding: 10px;
|
||||
}
|
||||
.model_preview_image[data-v-32e6c4d9] {
|
||||
margin: auto;
|
||||
width: -moz-fit-content;
|
||||
width: fit-content;
|
||||
}
|
||||
.model_preview_image img[data-v-32e6c4d9] {
|
||||
max-width: 100%;
|
||||
max-height: 150px;
|
||||
-o-object-fit: contain;
|
||||
object-fit: contain;
|
||||
}
|
||||
.model_preview_title[data-v-32e6c4d9] {
|
||||
font-weight: bold;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
.model_preview_top_container[data-v-32e6c4d9] {
|
||||
text-align: center;
|
||||
line-height: 0.5;
|
||||
}
|
||||
.model_preview_filename[data-v-32e6c4d9],
|
||||
.model_preview_author[data-v-32e6c4d9],
|
||||
.model_preview_architecture[data-v-32e6c4d9] {
|
||||
display: inline-block;
|
||||
text-align: center;
|
||||
margin: 5px;
|
||||
font-size: 10px;
|
||||
}
|
||||
.model_preview_prefix[data-v-32e6c4d9] {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.model-lib-model-icon-container[data-v-70b69131] {
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
left: 0;
|
||||
height: 1.5rem;
|
||||
vertical-align: top;
|
||||
width: 0px;
|
||||
}
|
||||
.model-lib-model-icon[data-v-70b69131] {
|
||||
background-size: cover;
|
||||
background-position: center;
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
left: -2.5rem;
|
||||
height: 2rem;
|
||||
width: 2rem;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
.pi-fake-spacer {
|
||||
height: 1px;
|
||||
width: 16px;
|
||||
}
|
||||
|
||||
[data-v-74b01bce] .comfy-vue-side-bar-body {
|
||||
background: var(--p-tree-background);
|
||||
}
|
||||
|
||||
[data-v-d2d58252] .comfy-vue-side-bar-body {
|
||||
background: var(--p-tree-background);
|
||||
}
|
||||
|
||||
[data-v-84e785b8] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton {
|
||||
position: relative;
|
||||
flex-shrink: 0;
|
||||
border-radius: 0px;
|
||||
background-color: transparent;
|
||||
padding-left: 0.5rem;
|
||||
padding-right: 0.5rem
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||
border-bottom-width: 2px;
|
||||
border-bottom-color: var(--p-button-text-primary-color)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||
visibility: visible
|
||||
}
|
||||
.status-indicator[data-v-84e785b8] {
|
||||
position: absolute;
|
||||
font-weight: 700;
|
||||
font-size: 1.5rem;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||
visibility: hidden
|
||||
}
|
||||
|
||||
.top-menubar[data-v-2ec1b620] .p-menubar-item-link svg {
|
||||
display: none;
|
||||
}
|
||||
[data-v-2ec1b620] .p-menubar-submenu.dropdown-direction-up {
|
||||
top: auto;
|
||||
bottom: 100%;
|
||||
flex-direction: column-reverse;
|
||||
}
|
||||
.keybinding-tag[data-v-2ec1b620] {
|
||||
background: var(--p-content-hover-background);
|
||||
border-color: var(--p-content-border-color);
|
||||
border-style: solid;
|
||||
}
|
||||
|
||||
[data-v-713442be] .p-inputtext {
|
||||
border-top-left-radius: 0;
|
||||
border-bottom-left-radius: 0;
|
||||
}
|
||||
|
||||
.comfyui-queue-button[data-v-fcd3efcd] .p-splitbutton-dropdown {
|
||||
border-top-right-radius: 0;
|
||||
border-bottom-right-radius: 0;
|
||||
}
|
||||
|
||||
.actionbar[data-v-bc6c78dd] {
|
||||
pointer-events: all;
|
||||
position: fixed;
|
||||
z-index: 1000;
|
||||
}
|
||||
.actionbar.is-docked[data-v-bc6c78dd] {
|
||||
position: static;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
padding: 0px;
|
||||
}
|
||||
.actionbar.is-dragging[data-v-bc6c78dd] {
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
[data-v-bc6c78dd] .p-panel-content {
|
||||
padding: 0.25rem;
|
||||
}
|
||||
[data-v-bc6c78dd] .p-panel-header {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfyui-menu[data-v-b13fdc92] {
|
||||
width: 100vw;
|
||||
background: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
font-size: 0.8em;
|
||||
box-sizing: border-box;
|
||||
z-index: 1000;
|
||||
order: 0;
|
||||
grid-column: 1/-1;
|
||||
max-height: 90vh;
|
||||
}
|
||||
.comfyui-menu.dropzone[data-v-b13fdc92] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
|
||||
background: var(--p-highlight-background-focus);
|
||||
}
|
||||
.comfyui-logo[data-v-b13fdc92] {
|
||||
font-size: 1.2em;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
user-select: none;
|
||||
cursor: default;
|
||||
}
|
||||
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
17465
web/assets/GraphView-CVV2XJjS.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
1
web/assets/GraphView-CVV2XJjS.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
865
web/assets/colorPalette-D5oi2-2V.js
generated
vendored
Normal file
@@ -0,0 +1,865 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { k as app, aP as LGraphCanvas, bO as useToastStore, ca as $el, z as LiteGraph } from "./index-DGAbdBYF.js";
|
||||
const colorPalettes = {
|
||||
dark: {
|
||||
id: "dark",
|
||||
name: "Dark (Default)",
|
||||
colors: {
|
||||
node_slot: {
|
||||
CLIP: "#FFD500",
|
||||
// bright yellow
|
||||
CLIP_VISION: "#A8DADC",
|
||||
// light blue-gray
|
||||
CLIP_VISION_OUTPUT: "#ad7452",
|
||||
// rusty brown-orange
|
||||
CONDITIONING: "#FFA931",
|
||||
// vibrant orange-yellow
|
||||
CONTROL_NET: "#6EE7B7",
|
||||
// soft mint green
|
||||
IMAGE: "#64B5F6",
|
||||
// bright sky blue
|
||||
LATENT: "#FF9CF9",
|
||||
// light pink-purple
|
||||
MASK: "#81C784",
|
||||
// muted green
|
||||
MODEL: "#B39DDB",
|
||||
// light lavender-purple
|
||||
STYLE_MODEL: "#C2FFAE",
|
||||
// light green-yellow
|
||||
VAE: "#FF6E6E",
|
||||
// bright red
|
||||
NOISE: "#B0B0B0",
|
||||
// gray
|
||||
GUIDER: "#66FFFF",
|
||||
// cyan
|
||||
SAMPLER: "#ECB4B4",
|
||||
// very soft red
|
||||
SIGMAS: "#CDFFCD",
|
||||
// soft lime green
|
||||
TAESD: "#DCC274"
|
||||
// cheesecake
|
||||
},
|
||||
litegraph_base: {
|
||||
BACKGROUND_IMAGE: "",
|
||||
CLEAR_BACKGROUND_COLOR: "#222",
|
||||
NODE_TITLE_COLOR: "#999",
|
||||
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#AAA",
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#333",
|
||||
NODE_DEFAULT_BGCOLOR: "#353535",
|
||||
NODE_DEFAULT_BOXCOLOR: "#666",
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT: 24,
|
||||
WIDGET_BGCOLOR: "#222",
|
||||
WIDGET_OUTLINE_COLOR: "#666",
|
||||
WIDGET_TEXT_COLOR: "#DDD",
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||
LINK_COLOR: "#9A9",
|
||||
EVENT_LINK_COLOR: "#A86",
|
||||
CONNECTING_LINK_COLOR: "#AFA",
|
||||
BADGE_FG_COLOR: "#FFF",
|
||||
BADGE_BG_COLOR: "#0F1F0F"
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#fff",
|
||||
"bg-color": "#202020",
|
||||
"comfy-menu-bg": "#353535",
|
||||
"comfy-input-bg": "#222",
|
||||
"input-text": "#ddd",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#4e4e4e",
|
||||
"tr-even-bg-color": "#222",
|
||||
"tr-odd-bg-color": "#353535",
|
||||
"content-bg": "#4e4e4e",
|
||||
"content-fg": "#fff",
|
||||
"content-hover-bg": "#222",
|
||||
"content-hover-fg": "#fff"
|
||||
}
|
||||
}
|
||||
},
|
||||
light: {
|
||||
id: "light",
|
||||
name: "Light",
|
||||
colors: {
|
||||
node_slot: {
|
||||
CLIP: "#FFA726",
|
||||
// orange
|
||||
CLIP_VISION: "#5C6BC0",
|
||||
// indigo
|
||||
CLIP_VISION_OUTPUT: "#8D6E63",
|
||||
// brown
|
||||
CONDITIONING: "#EF5350",
|
||||
// red
|
||||
CONTROL_NET: "#66BB6A",
|
||||
// green
|
||||
IMAGE: "#42A5F5",
|
||||
// blue
|
||||
LATENT: "#AB47BC",
|
||||
// purple
|
||||
MASK: "#9CCC65",
|
||||
// light green
|
||||
MODEL: "#7E57C2",
|
||||
// deep purple
|
||||
STYLE_MODEL: "#D4E157",
|
||||
// lime
|
||||
VAE: "#FF7043"
|
||||
// deep orange
|
||||
},
|
||||
litegraph_base: {
|
||||
BACKGROUND_IMAGE: "",
|
||||
CLEAR_BACKGROUND_COLOR: "lightgray",
|
||||
NODE_TITLE_COLOR: "#222",
|
||||
NODE_SELECTED_TITLE_COLOR: "#000",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#444",
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#F7F7F7",
|
||||
NODE_DEFAULT_BGCOLOR: "#F5F5F5",
|
||||
NODE_DEFAULT_BOXCOLOR: "#CCC",
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#000",
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.1)",
|
||||
DEFAULT_GROUP_FONT: 24,
|
||||
WIDGET_BGCOLOR: "#D4D4D4",
|
||||
WIDGET_OUTLINE_COLOR: "#999",
|
||||
WIDGET_TEXT_COLOR: "#222",
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#555",
|
||||
LINK_COLOR: "#4CAF50",
|
||||
EVENT_LINK_COLOR: "#FF9800",
|
||||
CONNECTING_LINK_COLOR: "#2196F3",
|
||||
BADGE_FG_COLOR: "#000",
|
||||
BADGE_BG_COLOR: "#FFF"
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#222",
|
||||
"bg-color": "#DDD",
|
||||
"comfy-menu-bg": "#F5F5F5",
|
||||
"comfy-input-bg": "#C9C9C9",
|
||||
"input-text": "#222",
|
||||
"descrip-text": "#444",
|
||||
"drag-text": "#555",
|
||||
"error-text": "#F44336",
|
||||
"border-color": "#888",
|
||||
"tr-even-bg-color": "#f9f9f9",
|
||||
"tr-odd-bg-color": "#fff",
|
||||
"content-bg": "#e0e0e0",
|
||||
"content-fg": "#222",
|
||||
"content-hover-bg": "#adadad",
|
||||
"content-hover-fg": "#222"
|
||||
}
|
||||
}
|
||||
},
|
||||
solarized: {
|
||||
id: "solarized",
|
||||
name: "Solarized",
|
||||
colors: {
|
||||
node_slot: {
|
||||
CLIP: "#2AB7CA",
|
||||
// light blue
|
||||
CLIP_VISION: "#6c71c4",
|
||||
// blue violet
|
||||
CLIP_VISION_OUTPUT: "#859900",
|
||||
// olive green
|
||||
CONDITIONING: "#d33682",
|
||||
// magenta
|
||||
CONTROL_NET: "#d1ffd7",
|
||||
// light mint green
|
||||
IMAGE: "#5940bb",
|
||||
// deep blue violet
|
||||
LATENT: "#268bd2",
|
||||
// blue
|
||||
MASK: "#CCC9E7",
|
||||
// light purple-gray
|
||||
MODEL: "#dc322f",
|
||||
// red
|
||||
STYLE_MODEL: "#1a998a",
|
||||
// teal
|
||||
UPSCALE_MODEL: "#054A29",
|
||||
// dark green
|
||||
VAE: "#facfad"
|
||||
// light pink-orange
|
||||
},
|
||||
litegraph_base: {
|
||||
NODE_TITLE_COLOR: "#fdf6e3",
|
||||
// Base3
|
||||
NODE_SELECTED_TITLE_COLOR: "#A9D400",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#657b83",
|
||||
// Base00
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#094656",
|
||||
NODE_DEFAULT_BGCOLOR: "#073642",
|
||||
// Base02
|
||||
NODE_DEFAULT_BOXCOLOR: "#839496",
|
||||
// Base0
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#fdf6e3",
|
||||
// Base3
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT: 24,
|
||||
WIDGET_BGCOLOR: "#002b36",
|
||||
// Base03
|
||||
WIDGET_OUTLINE_COLOR: "#839496",
|
||||
// Base0
|
||||
WIDGET_TEXT_COLOR: "#fdf6e3",
|
||||
// Base3
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#93a1a1",
|
||||
// Base1
|
||||
LINK_COLOR: "#2aa198",
|
||||
// Solarized Cyan
|
||||
EVENT_LINK_COLOR: "#268bd2",
|
||||
// Solarized Blue
|
||||
CONNECTING_LINK_COLOR: "#859900"
|
||||
// Solarized Green
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#fdf6e3",
|
||||
// Base3
|
||||
"bg-color": "#002b36",
|
||||
// Base03
|
||||
"comfy-menu-bg": "#073642",
|
||||
// Base02
|
||||
"comfy-input-bg": "#002b36",
|
||||
// Base03
|
||||
"input-text": "#93a1a1",
|
||||
// Base1
|
||||
"descrip-text": "#586e75",
|
||||
// Base01
|
||||
"drag-text": "#839496",
|
||||
// Base0
|
||||
"error-text": "#dc322f",
|
||||
// Solarized Red
|
||||
"border-color": "#657b83",
|
||||
// Base00
|
||||
"tr-even-bg-color": "#002b36",
|
||||
"tr-odd-bg-color": "#073642",
|
||||
"content-bg": "#657b83",
|
||||
"content-fg": "#fdf6e3",
|
||||
"content-hover-bg": "#002b36",
|
||||
"content-hover-fg": "#fdf6e3"
|
||||
}
|
||||
}
|
||||
},
|
||||
arc: {
|
||||
id: "arc",
|
||||
name: "Arc",
|
||||
colors: {
|
||||
node_slot: {
|
||||
BOOLEAN: "",
|
||||
CLIP: "#eacb8b",
|
||||
CLIP_VISION: "#A8DADC",
|
||||
CLIP_VISION_OUTPUT: "#ad7452",
|
||||
CONDITIONING: "#cf876f",
|
||||
CONTROL_NET: "#00d78d",
|
||||
CONTROL_NET_WEIGHTS: "",
|
||||
FLOAT: "",
|
||||
GLIGEN: "",
|
||||
IMAGE: "#80a1c0",
|
||||
IMAGEUPLOAD: "",
|
||||
INT: "",
|
||||
LATENT: "#b38ead",
|
||||
LATENT_KEYFRAME: "",
|
||||
MASK: "#a3bd8d",
|
||||
MODEL: "#8978a7",
|
||||
SAMPLER: "",
|
||||
SIGMAS: "",
|
||||
STRING: "",
|
||||
STYLE_MODEL: "#C2FFAE",
|
||||
T2I_ADAPTER_WEIGHTS: "",
|
||||
TAESD: "#DCC274",
|
||||
TIMESTEP_KEYFRAME: "",
|
||||
UPSCALE_MODEL: "",
|
||||
VAE: "#be616b"
|
||||
},
|
||||
litegraph_base: {
|
||||
BACKGROUND_IMAGE: "",
|
||||
CLEAR_BACKGROUND_COLOR: "#2b2f38",
|
||||
NODE_TITLE_COLOR: "#b2b7bd",
|
||||
NODE_SELECTED_TITLE_COLOR: "#FFF",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#AAA",
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#2b2f38",
|
||||
NODE_DEFAULT_BGCOLOR: "#242730",
|
||||
NODE_DEFAULT_BOXCOLOR: "#6e7581",
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#FFF",
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT: 22,
|
||||
WIDGET_BGCOLOR: "#2b2f38",
|
||||
WIDGET_OUTLINE_COLOR: "#6e7581",
|
||||
WIDGET_TEXT_COLOR: "#DDD",
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#b2b7bd",
|
||||
LINK_COLOR: "#9A9",
|
||||
EVENT_LINK_COLOR: "#A86",
|
||||
CONNECTING_LINK_COLOR: "#AFA"
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#fff",
|
||||
"bg-color": "#2b2f38",
|
||||
"comfy-menu-bg": "#242730",
|
||||
"comfy-input-bg": "#2b2f38",
|
||||
"input-text": "#ddd",
|
||||
"descrip-text": "#b2b7bd",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#6e7581",
|
||||
"tr-even-bg-color": "#2b2f38",
|
||||
"tr-odd-bg-color": "#242730",
|
||||
"content-bg": "#6e7581",
|
||||
"content-fg": "#fff",
|
||||
"content-hover-bg": "#2b2f38",
|
||||
"content-hover-fg": "#fff"
|
||||
}
|
||||
}
|
||||
},
|
||||
nord: {
|
||||
id: "nord",
|
||||
name: "Nord",
|
||||
colors: {
|
||||
node_slot: {
|
||||
BOOLEAN: "",
|
||||
CLIP: "#eacb8b",
|
||||
CLIP_VISION: "#A8DADC",
|
||||
CLIP_VISION_OUTPUT: "#ad7452",
|
||||
CONDITIONING: "#cf876f",
|
||||
CONTROL_NET: "#00d78d",
|
||||
CONTROL_NET_WEIGHTS: "",
|
||||
FLOAT: "",
|
||||
GLIGEN: "",
|
||||
IMAGE: "#80a1c0",
|
||||
IMAGEUPLOAD: "",
|
||||
INT: "",
|
||||
LATENT: "#b38ead",
|
||||
LATENT_KEYFRAME: "",
|
||||
MASK: "#a3bd8d",
|
||||
MODEL: "#8978a7",
|
||||
SAMPLER: "",
|
||||
SIGMAS: "",
|
||||
STRING: "",
|
||||
STYLE_MODEL: "#C2FFAE",
|
||||
T2I_ADAPTER_WEIGHTS: "",
|
||||
TAESD: "#DCC274",
|
||||
TIMESTEP_KEYFRAME: "",
|
||||
UPSCALE_MODEL: "",
|
||||
VAE: "#be616b"
|
||||
},
|
||||
litegraph_base: {
|
||||
BACKGROUND_IMAGE: "",
|
||||
CLEAR_BACKGROUND_COLOR: "#212732",
|
||||
NODE_TITLE_COLOR: "#999",
|
||||
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#bcc2c8",
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#2e3440",
|
||||
NODE_DEFAULT_BGCOLOR: "#161b22",
|
||||
NODE_DEFAULT_BOXCOLOR: "#545d70",
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT: 24,
|
||||
WIDGET_BGCOLOR: "#2e3440",
|
||||
WIDGET_OUTLINE_COLOR: "#545d70",
|
||||
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||
LINK_COLOR: "#9A9",
|
||||
EVENT_LINK_COLOR: "#A86",
|
||||
CONNECTING_LINK_COLOR: "#AFA"
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#2e3440",
|
||||
"comfy-menu-bg": "#161b22",
|
||||
"comfy-input-bg": "#2e3440",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#545d70",
|
||||
"tr-even-bg-color": "#2e3440",
|
||||
"tr-odd-bg-color": "#161b22",
|
||||
"content-bg": "#545d70",
|
||||
"content-fg": "#e5eaf0",
|
||||
"content-hover-bg": "#2e3440",
|
||||
"content-hover-fg": "#e5eaf0"
|
||||
}
|
||||
}
|
||||
},
|
||||
github: {
|
||||
id: "github",
|
||||
name: "Github",
|
||||
colors: {
|
||||
node_slot: {
|
||||
BOOLEAN: "",
|
||||
CLIP: "#eacb8b",
|
||||
CLIP_VISION: "#A8DADC",
|
||||
CLIP_VISION_OUTPUT: "#ad7452",
|
||||
CONDITIONING: "#cf876f",
|
||||
CONTROL_NET: "#00d78d",
|
||||
CONTROL_NET_WEIGHTS: "",
|
||||
FLOAT: "",
|
||||
GLIGEN: "",
|
||||
IMAGE: "#80a1c0",
|
||||
IMAGEUPLOAD: "",
|
||||
INT: "",
|
||||
LATENT: "#b38ead",
|
||||
LATENT_KEYFRAME: "",
|
||||
MASK: "#a3bd8d",
|
||||
MODEL: "#8978a7",
|
||||
SAMPLER: "",
|
||||
SIGMAS: "",
|
||||
STRING: "",
|
||||
STYLE_MODEL: "#C2FFAE",
|
||||
T2I_ADAPTER_WEIGHTS: "",
|
||||
TAESD: "#DCC274",
|
||||
TIMESTEP_KEYFRAME: "",
|
||||
UPSCALE_MODEL: "",
|
||||
VAE: "#be616b"
|
||||
},
|
||||
litegraph_base: {
|
||||
BACKGROUND_IMAGE: "",
|
||||
CLEAR_BACKGROUND_COLOR: "#040506",
|
||||
NODE_TITLE_COLOR: "#999",
|
||||
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
|
||||
NODE_TEXT_SIZE: 14,
|
||||
NODE_TEXT_COLOR: "#bcc2c8",
|
||||
NODE_SUBTEXT_SIZE: 12,
|
||||
NODE_DEFAULT_COLOR: "#161b22",
|
||||
NODE_DEFAULT_BGCOLOR: "#13171d",
|
||||
NODE_DEFAULT_BOXCOLOR: "#30363d",
|
||||
NODE_DEFAULT_SHAPE: "box",
|
||||
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
|
||||
NODE_BYPASS_BGCOLOR: "#FF00FF",
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT: 24,
|
||||
WIDGET_BGCOLOR: "#161b22",
|
||||
WIDGET_OUTLINE_COLOR: "#30363d",
|
||||
WIDGET_TEXT_COLOR: "#bcc2c8",
|
||||
WIDGET_SECONDARY_TEXT_COLOR: "#999",
|
||||
LINK_COLOR: "#9A9",
|
||||
EVENT_LINK_COLOR: "#A86",
|
||||
CONNECTING_LINK_COLOR: "#AFA"
|
||||
},
|
||||
comfy_base: {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#161b22",
|
||||
"comfy-menu-bg": "#13171d",
|
||||
"comfy-input-bg": "#161b22",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#30363d",
|
||||
"tr-even-bg-color": "#161b22",
|
||||
"tr-odd-bg-color": "#13171d",
|
||||
"content-bg": "#30363d",
|
||||
"content-fg": "#e5eaf0",
|
||||
"content-hover-bg": "#161b22",
|
||||
"content-hover-fg": "#e5eaf0"
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
const id = "Comfy.ColorPalette";
|
||||
const idCustomColorPalettes = "Comfy.CustomColorPalettes";
|
||||
const defaultColorPaletteId = "dark";
|
||||
const els = {
|
||||
select: null
|
||||
};
|
||||
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
|
||||
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
|
||||
}, "getCustomColorPalettes");
|
||||
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
|
||||
return app.ui.settings.setSettingValue(
|
||||
idCustomColorPalettes,
|
||||
customColorPalettes
|
||||
);
|
||||
}, "setCustomColorPalettes");
|
||||
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
|
||||
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||
if (!colorPaletteId) {
|
||||
colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
||||
}
|
||||
if (colorPaletteId.startsWith("custom_")) {
|
||||
colorPaletteId = colorPaletteId.substr(7);
|
||||
let customColorPalettes = getCustomColorPalettes();
|
||||
if (customColorPalettes[colorPaletteId]) {
|
||||
return customColorPalettes[colorPaletteId];
|
||||
}
|
||||
}
|
||||
return colorPalettes[colorPaletteId];
|
||||
}, "getColorPalette");
|
||||
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
|
||||
app.ui.settings.setSettingValue(id, colorPaletteId);
|
||||
}, "setColorPalette");
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
init() {
|
||||
LGraphCanvas.prototype.updateBackground = function(image, clearBackgroundColor) {
|
||||
this._bg_img = new Image();
|
||||
this._bg_img.name = image;
|
||||
this._bg_img.src = image;
|
||||
this._bg_img.onload = () => {
|
||||
this.draw(true, true);
|
||||
};
|
||||
this.background_image = image;
|
||||
this.clear_background = true;
|
||||
this.clear_background_color = clearBackgroundColor;
|
||||
this._pattern = null;
|
||||
};
|
||||
},
|
||||
addCustomNodeDefs(node_defs) {
|
||||
const sortObjectKeys = /* @__PURE__ */ __name((unordered) => {
|
||||
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||
obj[key] = unordered[key];
|
||||
return obj;
|
||||
}, {});
|
||||
}, "sortObjectKeys");
|
||||
function getSlotTypes() {
|
||||
var types = [];
|
||||
const defs = node_defs;
|
||||
for (const nodeId in defs) {
|
||||
const nodeData = defs[nodeId];
|
||||
var inputs = nodeData["input"]["required"];
|
||||
if (nodeData["input"]["optional"] !== void 0) {
|
||||
inputs = Object.assign(
|
||||
{},
|
||||
nodeData["input"]["required"],
|
||||
nodeData["input"]["optional"]
|
||||
);
|
||||
}
|
||||
for (const inputName in inputs) {
|
||||
const inputData = inputs[inputName];
|
||||
const type = inputData[0];
|
||||
if (!Array.isArray(type)) {
|
||||
types.push(type);
|
||||
}
|
||||
}
|
||||
for (const o in nodeData["output"]) {
|
||||
const output = nodeData["output"][o];
|
||||
types.push(output);
|
||||
}
|
||||
}
|
||||
return types;
|
||||
}
|
||||
__name(getSlotTypes, "getSlotTypes");
|
||||
function completeColorPalette(colorPalette) {
|
||||
var types = getSlotTypes();
|
||||
for (const type of types) {
|
||||
if (!colorPalette.colors.node_slot[type]) {
|
||||
colorPalette.colors.node_slot[type] = "";
|
||||
}
|
||||
}
|
||||
colorPalette.colors.node_slot = sortObjectKeys(
|
||||
colorPalette.colors.node_slot
|
||||
);
|
||||
return colorPalette;
|
||||
}
|
||||
__name(completeColorPalette, "completeColorPalette");
|
||||
const getColorPaletteTemplate = /* @__PURE__ */ __name(async () => {
|
||||
let colorPalette = {
|
||||
id: "my_color_palette_unique_id",
|
||||
name: "My Color Palette",
|
||||
colors: {
|
||||
node_slot: {},
|
||||
litegraph_base: {},
|
||||
comfy_base: {}
|
||||
}
|
||||
};
|
||||
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
|
||||
for (const key in defaultColorPalette2.colors.litegraph_base) {
|
||||
if (!colorPalette.colors.litegraph_base[key]) {
|
||||
colorPalette.colors.litegraph_base[key] = "";
|
||||
}
|
||||
}
|
||||
for (const key in defaultColorPalette2.colors.comfy_base) {
|
||||
if (!colorPalette.colors.comfy_base[key]) {
|
||||
colorPalette.colors.comfy_base[key] = "";
|
||||
}
|
||||
}
|
||||
return completeColorPalette(colorPalette);
|
||||
}, "getColorPaletteTemplate");
|
||||
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||
if (typeof colorPalette !== "object") {
|
||||
useToastStore().addAlert("Invalid color palette.");
|
||||
return;
|
||||
}
|
||||
if (!colorPalette.id) {
|
||||
useToastStore().addAlert("Color palette missing id.");
|
||||
return;
|
||||
}
|
||||
if (!colorPalette.name) {
|
||||
useToastStore().addAlert("Color palette missing name.");
|
||||
return;
|
||||
}
|
||||
if (!colorPalette.colors) {
|
||||
useToastStore().addAlert("Color palette missing colors.");
|
||||
return;
|
||||
}
|
||||
if (colorPalette.colors.node_slot && typeof colorPalette.colors.node_slot !== "object") {
|
||||
useToastStore().addAlert("Invalid color palette colors.node_slot.");
|
||||
return;
|
||||
}
|
||||
const customColorPalettes = getCustomColorPalettes();
|
||||
customColorPalettes[colorPalette.id] = colorPalette;
|
||||
setCustomColorPalettes(customColorPalettes);
|
||||
for (const option of els.select.childNodes) {
|
||||
if (option.value === "custom_" + colorPalette.id) {
|
||||
els.select.removeChild(option);
|
||||
}
|
||||
}
|
||||
els.select.append(
|
||||
$el("option", {
|
||||
textContent: colorPalette.name + " (custom)",
|
||||
value: "custom_" + colorPalette.id,
|
||||
selected: true
|
||||
})
|
||||
);
|
||||
setColorPalette("custom_" + colorPalette.id);
|
||||
await loadColorPalette(colorPalette);
|
||||
}, "addCustomColorPalette");
|
||||
const deleteCustomColorPalette = /* @__PURE__ */ __name(async (colorPaletteId) => {
|
||||
const customColorPalettes = getCustomColorPalettes();
|
||||
delete customColorPalettes[colorPaletteId];
|
||||
setCustomColorPalettes(customColorPalettes);
|
||||
for (const opt of els.select.childNodes) {
|
||||
const option = opt;
|
||||
if (option.value === defaultColorPaletteId) {
|
||||
option.selected = true;
|
||||
}
|
||||
if (option.value === "custom_" + colorPaletteId) {
|
||||
els.select.removeChild(option);
|
||||
}
|
||||
}
|
||||
setColorPalette(defaultColorPaletteId);
|
||||
await loadColorPalette(getColorPalette());
|
||||
}, "deleteCustomColorPalette");
|
||||
const loadColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
|
||||
colorPalette = await completeColorPalette(colorPalette);
|
||||
if (colorPalette.colors) {
|
||||
if (colorPalette.colors.node_slot) {
|
||||
Object.assign(
|
||||
app.canvas.default_connection_color_byType,
|
||||
colorPalette.colors.node_slot
|
||||
);
|
||||
Object.assign(
|
||||
LGraphCanvas.link_type_colors,
|
||||
colorPalette.colors.node_slot
|
||||
);
|
||||
}
|
||||
if (colorPalette.colors.litegraph_base) {
|
||||
app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR;
|
||||
app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR;
|
||||
for (const key in colorPalette.colors.litegraph_base) {
|
||||
if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) {
|
||||
LiteGraph[key] = colorPalette.colors.litegraph_base[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (colorPalette.colors.comfy_base) {
|
||||
const rootStyle = document.documentElement.style;
|
||||
for (const key in colorPalette.colors.comfy_base) {
|
||||
rootStyle.setProperty(
|
||||
"--" + key,
|
||||
colorPalette.colors.comfy_base[key]
|
||||
);
|
||||
}
|
||||
}
|
||||
if (colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR) {
|
||||
app.bypassBgColor = colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR;
|
||||
}
|
||||
app.canvas.draw(true, true);
|
||||
}
|
||||
}, "loadColorPalette");
|
||||
const fileInput = $el("input", {
|
||||
type: "file",
|
||||
accept: ".json",
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
onchange: /* @__PURE__ */ __name(() => {
|
||||
const file = fileInput.files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
await addCustomColorPalette(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}
|
||||
}, "onchange")
|
||||
});
|
||||
app.ui.settings.addSetting({
|
||||
id,
|
||||
category: ["Comfy", "ColorPalette"],
|
||||
name: "Color Palette",
|
||||
type: /* @__PURE__ */ __name((name, setter, value) => {
|
||||
const options = [
|
||||
...Object.values(colorPalettes).map(
|
||||
(c) => $el("option", {
|
||||
textContent: c.name,
|
||||
value: c.id,
|
||||
selected: c.id === value
|
||||
})
|
||||
),
|
||||
...Object.values(getCustomColorPalettes()).map(
|
||||
(c) => $el("option", {
|
||||
textContent: `${c.name} (custom)`,
|
||||
value: `custom_${c.id}`,
|
||||
selected: `custom_${c.id}` === value
|
||||
})
|
||||
)
|
||||
];
|
||||
els.select = $el(
|
||||
"select",
|
||||
{
|
||||
style: {
|
||||
marginBottom: "0.15rem",
|
||||
width: "100%"
|
||||
},
|
||||
onchange: /* @__PURE__ */ __name((e) => {
|
||||
setter(e.target.value);
|
||||
}, "onchange")
|
||||
},
|
||||
options
|
||||
);
|
||||
return $el("tr", [
|
||||
$el("td", [
|
||||
els.select,
|
||||
$el(
|
||||
"div",
|
||||
{
|
||||
style: {
|
||||
display: "grid",
|
||||
gap: "4px",
|
||||
gridAutoFlow: "column"
|
||||
}
|
||||
},
|
||||
[
|
||||
$el("input", {
|
||||
type: "button",
|
||||
value: "Export",
|
||||
onclick: /* @__PURE__ */ __name(async () => {
|
||||
const colorPaletteId = app.ui.settings.getSettingValue(
|
||||
id,
|
||||
defaultColorPaletteId
|
||||
);
|
||||
const colorPalette = await completeColorPalette(
|
||||
getColorPalette(colorPaletteId)
|
||||
);
|
||||
const json = JSON.stringify(colorPalette, null, 2);
|
||||
const blob = new Blob([json], { type: "application/json" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: colorPaletteId + ".json",
|
||||
style: { display: "none" },
|
||||
parent: document.body
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function() {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
}, "onclick")
|
||||
}),
|
||||
$el("input", {
|
||||
type: "button",
|
||||
value: "Import",
|
||||
onclick: /* @__PURE__ */ __name(() => {
|
||||
fileInput.click();
|
||||
}, "onclick")
|
||||
}),
|
||||
$el("input", {
|
||||
type: "button",
|
||||
value: "Template",
|
||||
onclick: /* @__PURE__ */ __name(async () => {
|
||||
const colorPalette = await getColorPaletteTemplate();
|
||||
const json = JSON.stringify(colorPalette, null, 2);
|
||||
const blob = new Blob([json], { type: "application/json" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: "color_palette.json",
|
||||
style: { display: "none" },
|
||||
parent: document.body
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function() {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
}, "onclick")
|
||||
}),
|
||||
$el("input", {
|
||||
type: "button",
|
||||
value: "Delete",
|
||||
onclick: /* @__PURE__ */ __name(async () => {
|
||||
let colorPaletteId = app.ui.settings.getSettingValue(
|
||||
id,
|
||||
defaultColorPaletteId
|
||||
);
|
||||
if (colorPalettes[colorPaletteId]) {
|
||||
useToastStore().addAlert(
|
||||
"You cannot delete a built-in color palette."
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (colorPaletteId.startsWith("custom_")) {
|
||||
colorPaletteId = colorPaletteId.substr(7);
|
||||
}
|
||||
await deleteCustomColorPalette(colorPaletteId);
|
||||
}, "onclick")
|
||||
})
|
||||
]
|
||||
)
|
||||
])
|
||||
]);
|
||||
}, "type"),
|
||||
defaultValue: defaultColorPaletteId,
|
||||
async onChange(value) {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
let palette = colorPalettes[value];
|
||||
if (palette) {
|
||||
await loadColorPalette(palette);
|
||||
} else if (value.startsWith("custom_")) {
|
||||
value = value.substr(7);
|
||||
let customColorPalettes = getCustomColorPalettes();
|
||||
if (customColorPalettes[value]) {
|
||||
palette = customColorPalettes[value];
|
||||
await loadColorPalette(customColorPalettes[value]);
|
||||
}
|
||||
}
|
||||
let { BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR } = palette.colors.litegraph_base;
|
||||
if (BACKGROUND_IMAGE === void 0 || CLEAR_BACKGROUND_COLOR === void 0) {
|
||||
const base = colorPalettes["dark"].colors.litegraph_base;
|
||||
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
|
||||
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
|
||||
}
|
||||
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
|
||||
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
|
||||
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
|
||||
export {
|
||||
defaultColorPalette as d,
|
||||
getColorPalette as g
|
||||
};
|
||||
//# sourceMappingURL=colorPalette-D5oi2-2V.js.map
|
||||
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
1
web/assets/colorPalette-D5oi2-2V.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/index-BD-Ia1C4.js.map
generated
vendored
1
web/assets/index-BD-Ia1C4.js.map
generated
vendored
File diff suppressed because one or more lines are too long
5673
web/assets/index-_5czGnTA.css → web/assets/index-BHJGjcJh.css
generated
vendored
5673
web/assets/index-_5czGnTA.css → web/assets/index-BHJGjcJh.css
generated
vendored
File diff suppressed because it is too large
Load Diff
2018
web/assets/index-BD-Ia1C4.js → web/assets/index-BMC1ey-i.js
generated
vendored
2018
web/assets/index-BD-Ia1C4.js → web/assets/index-BMC1ey-i.js
generated
vendored
File diff suppressed because it is too large
Load Diff
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
1
web/assets/index-BMC1ey-i.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
0
web/assets/index-DjWyclij.css → web/assets/index-BRhY6FpL.css
generated
vendored
1
web/assets/index-CI3N807S.js.map
generated
vendored
1
web/assets/index-CI3N807S.js.map
generated
vendored
File diff suppressed because one or more lines are too long
164716
web/assets/index-CI3N807S.js → web/assets/index-DGAbdBYF.js
generated
vendored
164716
web/assets/index-CI3N807S.js → web/assets/index-DGAbdBYF.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/index-DGAbdBYF.js.map
generated
vendored
Normal file
1
web/assets/index-DGAbdBYF.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
2602
web/assets/sorted-custom-node-map.json
generated
vendored
Normal file
2602
web/assets/sorted-custom-node-map.json
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
34
web/assets/userSelection-BGzn1LuN.css → web/assets/userSelection-CmI-fOSC.css
generated
vendored
34
web/assets/userSelection-BGzn1LuN.css → web/assets/userSelection-CmI-fOSC.css
generated
vendored
@@ -1,3 +1,37 @@
|
||||
.lds-ring {
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
width: 1em;
|
||||
height: 1em;
|
||||
}
|
||||
.lds-ring div {
|
||||
box-sizing: border-box;
|
||||
display: block;
|
||||
position: absolute;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border: 0.15em solid #fff;
|
||||
border-radius: 50%;
|
||||
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
|
||||
border-color: #fff transparent transparent transparent;
|
||||
}
|
||||
.lds-ring div:nth-child(1) {
|
||||
animation-delay: -0.45s;
|
||||
}
|
||||
.lds-ring div:nth-child(2) {
|
||||
animation-delay: -0.3s;
|
||||
}
|
||||
.lds-ring div:nth-child(3) {
|
||||
animation-delay: -0.15s;
|
||||
}
|
||||
@keyframes lds-ring {
|
||||
0% {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
.comfy-user-selection {
|
||||
width: 100vw;
|
||||
height: 100vh;
|
||||
1
web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
1
web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
File diff suppressed because one or more lines are too long
13
web/assets/userSelection-CyXKCVy3.js → web/assets/userSelection-Duxc-t_S.js
generated
vendored
13
web/assets/userSelection-CyXKCVy3.js → web/assets/userSelection-Duxc-t_S.js
generated
vendored
@@ -1,6 +1,15 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
|
||||
import { b4 as api, ca as $el } from "./index-DGAbdBYF.js";
|
||||
function createSpinner() {
|
||||
const div = document.createElement("div");
|
||||
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
|
||||
return div.firstElementChild;
|
||||
}
|
||||
__name(createSpinner, "createSpinner");
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.spinner = window.comfyAPI.spinner || {};
|
||||
window.comfyAPI.spinner.createSpinner = createSpinner;
|
||||
class UserSelectionScreen {
|
||||
static {
|
||||
__name(this, "UserSelectionScreen");
|
||||
@@ -117,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
||||
export {
|
||||
UserSelectionScreen
|
||||
};
|
||||
//# sourceMappingURL=userSelection-CyXKCVy3.js.map
|
||||
//# sourceMappingURL=userSelection-Duxc-t_S.js.map
|
||||
1
web/assets/userSelection-Duxc-t_S.js.map
generated
vendored
Normal file
1
web/assets/userSelection-Duxc-t_S.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
756
web/assets/widgetInputs-DdoWwzg5.js
generated
vendored
Normal file
756
web/assets/widgetInputs-DdoWwzg5.js
generated
vendored
Normal file
@@ -0,0 +1,756 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { l as LGraphNode, k as app, cf as applyTextReplacements, ce as ComfyWidgets, ci as addValueControlWidgets, z as LiteGraph } from "./index-DGAbdBYF.js";
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = [
|
||||
"STRING",
|
||||
"combo",
|
||||
"number",
|
||||
"toggle",
|
||||
"BOOLEAN",
|
||||
"text",
|
||||
"string"
|
||||
];
|
||||
const CONFIG = Symbol();
|
||||
const GET_CONFIG = Symbol();
|
||||
const TARGET = Symbol();
|
||||
const replacePropertyName = "Run widget replace on values";
|
||||
class PrimitiveNode extends LGraphNode {
|
||||
static {
|
||||
__name(this, "PrimitiveNode");
|
||||
}
|
||||
controlValues;
|
||||
lastType;
|
||||
static category;
|
||||
constructor(title) {
|
||||
super(title);
|
||||
this.addOutput("connect to widget input", "*");
|
||||
this.serialize_widgets = true;
|
||||
this.isVirtualNode = true;
|
||||
if (!this.properties || !(replacePropertyName in this.properties)) {
|
||||
this.addProperty(replacePropertyName, false, "boolean");
|
||||
}
|
||||
}
|
||||
applyToGraph(extraLinks = []) {
|
||||
if (!this.outputs[0].links?.length) return;
|
||||
function get_links(node) {
|
||||
let links2 = [];
|
||||
for (const l of node.outputs[0].links) {
|
||||
const linkInfo = app.graph.links[l];
|
||||
const n = node.graph.getNodeById(linkInfo.target_id);
|
||||
if (n.type == "Reroute") {
|
||||
links2 = links2.concat(get_links(n));
|
||||
} else {
|
||||
links2.push(l);
|
||||
}
|
||||
}
|
||||
return links2;
|
||||
}
|
||||
__name(get_links, "get_links");
|
||||
let links = [
|
||||
...get_links(this).map((l) => app.graph.links[l]),
|
||||
...extraLinks
|
||||
];
|
||||
let v = this.widgets?.[0].value;
|
||||
if (v && this.properties[replacePropertyName]) {
|
||||
v = applyTextReplacements(app, v);
|
||||
}
|
||||
for (const linkInfo of links) {
|
||||
const node = this.graph.getNodeById(linkInfo.target_id);
|
||||
const input = node.inputs[linkInfo.target_slot];
|
||||
let widget;
|
||||
if (input.widget[TARGET]) {
|
||||
widget = input.widget[TARGET];
|
||||
} else {
|
||||
const widgetName = input.widget.name;
|
||||
if (widgetName) {
|
||||
widget = node.widgets.find((w) => w.name === widgetName);
|
||||
}
|
||||
}
|
||||
if (widget) {
|
||||
widget.value = v;
|
||||
if (widget.callback) {
|
||||
widget.callback(
|
||||
widget.value,
|
||||
app.canvas,
|
||||
node,
|
||||
app.canvas.graph_mouse,
|
||||
{}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
refreshComboInNode() {
|
||||
const widget = this.widgets?.[0];
|
||||
if (widget?.type === "combo") {
|
||||
widget.options.values = this.outputs[0].widget[GET_CONFIG]()[0];
|
||||
if (!widget.options.values.includes(widget.value)) {
|
||||
widget.value = widget.options.values[0];
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
onAfterGraphConfigured() {
|
||||
if (this.outputs[0].links?.length && !this.widgets?.length) {
|
||||
if (!this.#onFirstConnection()) return;
|
||||
if (this.widgets) {
|
||||
for (let i = 0; i < this.widgets_values.length; i++) {
|
||||
const w = this.widgets[i];
|
||||
if (w) {
|
||||
w.value = this.widgets_values[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
this.#mergeWidgetConfig();
|
||||
}
|
||||
}
|
||||
onConnectionsChange(_, index, connected) {
|
||||
if (app.configuringGraph) {
|
||||
return;
|
||||
}
|
||||
const links = this.outputs[0].links;
|
||||
if (connected) {
|
||||
if (links?.length && !this.widgets?.length) {
|
||||
this.#onFirstConnection();
|
||||
}
|
||||
} else {
|
||||
this.#mergeWidgetConfig();
|
||||
if (!links?.length) {
|
||||
this.onLastDisconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
onConnectOutput(slot, type, input, target_node, target_slot) {
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return false;
|
||||
}
|
||||
if (this.outputs[slot].links?.length) {
|
||||
const valid = this.#isValidConnection(input);
|
||||
if (valid) {
|
||||
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
}
|
||||
#onFirstConnection(recreating) {
|
||||
if (!this.outputs[0].links) {
|
||||
this.onLastDisconnect();
|
||||
return;
|
||||
}
|
||||
const linkId = this.outputs[0].links[0];
|
||||
const link = this.graph.links[linkId];
|
||||
if (!link) return;
|
||||
const theirNode = this.graph.getNodeById(link.target_id);
|
||||
if (!theirNode || !theirNode.inputs) return;
|
||||
const input = theirNode.inputs[link.target_slot];
|
||||
if (!input) return;
|
||||
let widget;
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return;
|
||||
widget = { name: input.name, [GET_CONFIG]: () => [input.type, {}] };
|
||||
} else {
|
||||
widget = input.widget;
|
||||
}
|
||||
const config = widget[GET_CONFIG]?.();
|
||||
if (!config) return;
|
||||
const { type } = getWidgetType(config);
|
||||
this.outputs[0].type = type;
|
||||
this.outputs[0].name = type;
|
||||
this.outputs[0].widget = widget;
|
||||
this.#createWidget(
|
||||
widget[CONFIG] ?? config,
|
||||
theirNode,
|
||||
widget.name,
|
||||
recreating,
|
||||
widget[TARGET]
|
||||
);
|
||||
}
|
||||
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
|
||||
let type = inputData[0];
|
||||
if (type instanceof Array) {
|
||||
type = "COMBO";
|
||||
}
|
||||
const size = this.size;
|
||||
let widget;
|
||||
if (type in ComfyWidgets) {
|
||||
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
|
||||
} else {
|
||||
widget = this.addWidget(type, "value", null, () => {
|
||||
}, {});
|
||||
}
|
||||
if (targetWidget) {
|
||||
widget.value = targetWidget.value;
|
||||
} else if (node?.widgets && widget) {
|
||||
const theirWidget = node.widgets.find((w) => w.name === widgetName);
|
||||
if (theirWidget) {
|
||||
widget.value = theirWidget.value;
|
||||
}
|
||||
}
|
||||
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
|
||||
let control_value = this.widgets_values?.[1];
|
||||
if (!control_value) {
|
||||
control_value = "fixed";
|
||||
}
|
||||
addValueControlWidgets(
|
||||
this,
|
||||
widget,
|
||||
control_value,
|
||||
void 0,
|
||||
inputData
|
||||
);
|
||||
let filter = this.widgets_values?.[2];
|
||||
if (filter && this.widgets.length === 3) {
|
||||
this.widgets[2].value = filter;
|
||||
}
|
||||
}
|
||||
const controlValues = this.controlValues;
|
||||
if (this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
|
||||
for (let i = 0; i < controlValues.length; i++) {
|
||||
this.widgets[i + 1].value = controlValues[i];
|
||||
}
|
||||
}
|
||||
const callback = widget.callback;
|
||||
const self = this;
|
||||
widget.callback = function() {
|
||||
const r = callback ? callback.apply(this, arguments) : void 0;
|
||||
self.applyToGraph();
|
||||
return r;
|
||||
};
|
||||
this.size = [
|
||||
Math.max(this.size[0], size[0]),
|
||||
Math.max(this.size[1], size[1])
|
||||
];
|
||||
if (!recreating) {
|
||||
const sz = this.computeSize();
|
||||
if (this.size[0] < sz[0]) {
|
||||
this.size[0] = sz[0];
|
||||
}
|
||||
if (this.size[1] < sz[1]) {
|
||||
this.size[1] = sz[1];
|
||||
}
|
||||
requestAnimationFrame(() => {
|
||||
if (this.onResize) {
|
||||
this.onResize(this.size);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
recreateWidget() {
|
||||
const values = this.widgets?.map((w) => w.value);
|
||||
this.#removeWidgets();
|
||||
this.#onFirstConnection(true);
|
||||
if (values?.length) {
|
||||
for (let i = 0; i < this.widgets?.length; i++)
|
||||
this.widgets[i].value = values[i];
|
||||
}
|
||||
return this.widgets?.[0];
|
||||
}
|
||||
#mergeWidgetConfig() {
|
||||
const output = this.outputs[0];
|
||||
const links = output.links;
|
||||
const hasConfig = !!output.widget[CONFIG];
|
||||
if (hasConfig) {
|
||||
delete output.widget[CONFIG];
|
||||
}
|
||||
if (links?.length < 2 && hasConfig) {
|
||||
if (links.length) {
|
||||
this.recreateWidget();
|
||||
}
|
||||
return;
|
||||
}
|
||||
const config1 = output.widget[GET_CONFIG]();
|
||||
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||
if (!isNumber) return;
|
||||
for (const linkId of links) {
|
||||
const link = app.graph.links[linkId];
|
||||
if (!link) continue;
|
||||
const theirNode = app.graph.getNodeById(link.target_id);
|
||||
const theirInput = theirNode.inputs[link.target_slot];
|
||||
this.#isValidConnection(theirInput, hasConfig);
|
||||
}
|
||||
}
|
||||
isValidWidgetLink(originSlot, targetNode, targetWidget) {
|
||||
const config2 = getConfig.call(targetNode, targetWidget.name) ?? [
|
||||
targetWidget.type,
|
||||
targetWidget.options || {}
|
||||
];
|
||||
if (!isConvertibleWidget(targetWidget, config2)) return false;
|
||||
const output = this.outputs[originSlot];
|
||||
if (!(output.widget?.[CONFIG] ?? output.widget?.[GET_CONFIG]())) {
|
||||
return true;
|
||||
}
|
||||
return !!mergeIfValid.call(this, output, config2);
|
||||
}
|
||||
#isValidConnection(input, forceUpdate) {
|
||||
const output = this.outputs[0];
|
||||
const config2 = input.widget[GET_CONFIG]();
|
||||
return !!mergeIfValid.call(
|
||||
this,
|
||||
output,
|
||||
config2,
|
||||
forceUpdate,
|
||||
this.recreateWidget
|
||||
);
|
||||
}
|
||||
#removeWidgets() {
|
||||
if (this.widgets) {
|
||||
for (const w of this.widgets) {
|
||||
if (w.onRemove) {
|
||||
w.onRemove();
|
||||
}
|
||||
}
|
||||
this.controlValues = [];
|
||||
this.lastType = this.widgets[0]?.type;
|
||||
for (let i = 1; i < this.widgets.length; i++) {
|
||||
this.controlValues.push(this.widgets[i].value);
|
||||
}
|
||||
setTimeout(() => {
|
||||
delete this.lastType;
|
||||
delete this.controlValues;
|
||||
}, 15);
|
||||
this.widgets.length = 0;
|
||||
}
|
||||
}
|
||||
onLastDisconnect() {
|
||||
this.outputs[0].type = "*";
|
||||
this.outputs[0].name = "connect to widget input";
|
||||
delete this.outputs[0].widget;
|
||||
this.#removeWidgets();
|
||||
}
|
||||
}
|
||||
function getWidgetConfig(slot) {
|
||||
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
|
||||
}
|
||||
__name(getWidgetConfig, "getWidgetConfig");
|
||||
function getConfig(widgetName) {
|
||||
const { nodeData } = this.constructor;
|
||||
return nodeData?.input?.required?.[widgetName] ?? nodeData?.input?.optional?.[widgetName];
|
||||
}
|
||||
__name(getConfig, "getConfig");
|
||||
function isConvertibleWidget(widget, config) {
|
||||
return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput;
|
||||
}
|
||||
__name(isConvertibleWidget, "isConvertibleWidget");
|
||||
function hideWidget(node, widget, suffix = "") {
|
||||
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
|
||||
widget.origType = widget.type;
|
||||
widget.origComputeSize = widget.computeSize;
|
||||
widget.origSerializeValue = widget.serializeValue;
|
||||
widget.computeSize = () => [0, -4];
|
||||
widget.type = CONVERTED_TYPE + suffix;
|
||||
widget.serializeValue = () => {
|
||||
if (!node.inputs) {
|
||||
return void 0;
|
||||
}
|
||||
let node_input = node.inputs.find((i) => i.widget?.name === widget.name);
|
||||
if (!node_input || !node_input.link) {
|
||||
return void 0;
|
||||
}
|
||||
return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
|
||||
};
|
||||
if (widget.linkedWidgets) {
|
||||
for (const w of widget.linkedWidgets) {
|
||||
hideWidget(node, w, ":" + widget.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
__name(hideWidget, "hideWidget");
|
||||
function showWidget(widget) {
|
||||
widget.type = widget.origType;
|
||||
widget.computeSize = widget.origComputeSize;
|
||||
widget.serializeValue = widget.origSerializeValue;
|
||||
delete widget.origType;
|
||||
delete widget.origComputeSize;
|
||||
delete widget.origSerializeValue;
|
||||
if (widget.linkedWidgets) {
|
||||
for (const w of widget.linkedWidgets) {
|
||||
showWidget(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
__name(showWidget, "showWidget");
|
||||
function convertToInput(node, widget, config) {
|
||||
hideWidget(node, widget);
|
||||
const { type } = getWidgetType(config);
|
||||
const sz = node.size;
|
||||
const inputIsOptional = !!widget.options?.inputIsOptional;
|
||||
const input = node.addInput(widget.name, type, {
|
||||
widget: { name: widget.name, [GET_CONFIG]: () => config },
|
||||
...inputIsOptional ? { shape: LiteGraph.SlotShape.HollowCircle } : {}
|
||||
});
|
||||
for (const widget2 of node.widgets) {
|
||||
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
|
||||
}
|
||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||
return input;
|
||||
}
|
||||
__name(convertToInput, "convertToInput");
|
||||
function convertToWidget(node, widget) {
|
||||
showWidget(widget);
|
||||
const sz = node.size;
|
||||
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
|
||||
for (const widget2 of node.widgets) {
|
||||
widget2.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
|
||||
}
|
||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||
}
|
||||
__name(convertToWidget, "convertToWidget");
|
||||
function getWidgetType(config) {
|
||||
let type = config[0];
|
||||
if (type instanceof Array) {
|
||||
type = "COMBO";
|
||||
}
|
||||
return { type };
|
||||
}
|
||||
__name(getWidgetType, "getWidgetType");
|
||||
function isValidCombo(combo, obj) {
|
||||
if (!(obj instanceof Array)) {
|
||||
console.log(`connection rejected: tried to connect combo to ${obj}`);
|
||||
return false;
|
||||
}
|
||||
if (combo.length !== obj.length) {
|
||||
console.log(`connection rejected: combo lists dont match`);
|
||||
return false;
|
||||
}
|
||||
if (combo.find((v, i) => obj[i] !== v)) {
|
||||
console.log(`connection rejected: combo lists dont match`);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
__name(isValidCombo, "isValidCombo");
|
||||
function isPrimitiveNode(node) {
|
||||
return node.type === "PrimitiveNode";
|
||||
}
|
||||
__name(isPrimitiveNode, "isPrimitiveNode");
|
||||
function setWidgetConfig(slot, config, target) {
|
||||
if (!slot.widget) return;
|
||||
if (config) {
|
||||
slot.widget[GET_CONFIG] = () => config;
|
||||
slot.widget[TARGET] = target;
|
||||
} else {
|
||||
delete slot.widget;
|
||||
}
|
||||
if (slot.link) {
|
||||
const link = app.graph.links[slot.link];
|
||||
if (link) {
|
||||
const originNode = app.graph.getNodeById(link.origin_id);
|
||||
if (isPrimitiveNode(originNode)) {
|
||||
if (config) {
|
||||
originNode.recreateWidget();
|
||||
} else if (!app.configuringGraph) {
|
||||
originNode.disconnectOutput(0);
|
||||
originNode.onLastDisconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__name(setWidgetConfig, "setWidgetConfig");
|
||||
function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
|
||||
if (!config1) {
|
||||
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||
}
|
||||
if (config1[0] instanceof Array) {
|
||||
if (!isValidCombo(config1[0], config2[0])) return;
|
||||
} else if (config1[0] !== config2[0]) {
|
||||
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
|
||||
return;
|
||||
}
|
||||
const keys = /* @__PURE__ */ new Set([
|
||||
...Object.keys(config1[1] ?? {}),
|
||||
...Object.keys(config2[1] ?? {})
|
||||
]);
|
||||
let customConfig;
|
||||
const getCustomConfig = /* @__PURE__ */ __name(() => {
|
||||
if (!customConfig) {
|
||||
if (typeof structuredClone === "undefined") {
|
||||
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
|
||||
} else {
|
||||
customConfig = structuredClone(config1[1] ?? {});
|
||||
}
|
||||
}
|
||||
return customConfig;
|
||||
}, "getCustomConfig");
|
||||
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||
for (const k of keys.values()) {
|
||||
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline" && k !== "tooltip") {
|
||||
let v1 = config1[1][k];
|
||||
let v2 = config2[1]?.[k];
|
||||
if (v1 === v2 || !v1 && !v2) continue;
|
||||
if (isNumber) {
|
||||
if (k === "min") {
|
||||
const theirMax = config2[1]?.["max"];
|
||||
if (theirMax != null && v1 > theirMax) {
|
||||
console.log("connection rejected: min > max", v1, theirMax);
|
||||
return;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
|
||||
continue;
|
||||
} else if (k === "max") {
|
||||
const theirMin = config2[1]?.["min"];
|
||||
if (theirMin != null && v1 < theirMin) {
|
||||
console.log("connection rejected: max < min", v1, theirMin);
|
||||
return;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
|
||||
continue;
|
||||
} else if (k === "step") {
|
||||
let step;
|
||||
if (v1 == null) {
|
||||
step = v2;
|
||||
} else if (v2 == null) {
|
||||
step = v1;
|
||||
} else {
|
||||
if (v1 < v2) {
|
||||
const a = v2;
|
||||
v2 = v1;
|
||||
v1 = a;
|
||||
}
|
||||
if (v1 % v2) {
|
||||
console.log(
|
||||
"connection rejected: steps not divisible",
|
||||
"current:",
|
||||
v1,
|
||||
"new:",
|
||||
v2
|
||||
);
|
||||
return;
|
||||
}
|
||||
step = v1;
|
||||
}
|
||||
getCustomConfig()[k] = step;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (customConfig || forceUpdate) {
|
||||
if (customConfig) {
|
||||
output.widget[CONFIG] = [config1[0], customConfig];
|
||||
}
|
||||
const widget = recreateWidget?.call(this);
|
||||
if (widget) {
|
||||
const min = widget.options.min;
|
||||
const max = widget.options.max;
|
||||
if (min != null && widget.value < min) widget.value = min;
|
||||
if (max != null && widget.value > max) widget.value = max;
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
return { customConfig };
|
||||
}
|
||||
__name(mergeIfValid, "mergeIfValid");
|
||||
let useConversionSubmenusSetting;
|
||||
app.registerExtension({
|
||||
name: "Comfy.WidgetInputs",
|
||||
init() {
|
||||
useConversionSubmenusSetting = app.ui.settings.addSetting({
|
||||
id: "Comfy.NodeInputConversionSubmenus",
|
||||
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
|
||||
type: "boolean",
|
||||
defaultValue: true
|
||||
});
|
||||
},
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
|
||||
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
|
||||
nodeType.prototype.convertWidgetToInput = function(widget) {
|
||||
const config = getConfig.call(this, widget.name) ?? [
|
||||
widget.type,
|
||||
widget.options || {}
|
||||
];
|
||||
if (!isConvertibleWidget(widget, config)) return false;
|
||||
if (widget.type?.startsWith(CONVERTED_TYPE)) return false;
|
||||
convertToInput(this, widget, config);
|
||||
return true;
|
||||
};
|
||||
nodeType.prototype.getExtraMenuOptions = function(_, options) {
|
||||
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
|
||||
if (this.widgets) {
|
||||
let toInput = [];
|
||||
let toWidget = [];
|
||||
for (const w of this.widgets) {
|
||||
if (w.options?.forceInput) {
|
||||
continue;
|
||||
}
|
||||
if (w.type === CONVERTED_TYPE) {
|
||||
toWidget.push({
|
||||
content: `Convert ${w.name} to widget`,
|
||||
callback: /* @__PURE__ */ __name(() => convertToWidget(this, w), "callback")
|
||||
});
|
||||
} else {
|
||||
const config = getConfig.call(this, w.name) ?? [
|
||||
w.type,
|
||||
w.options || {}
|
||||
];
|
||||
if (isConvertibleWidget(w, config)) {
|
||||
toInput.push({
|
||||
content: `Convert ${w.name} to input`,
|
||||
callback: /* @__PURE__ */ __name(() => convertToInput(this, w, config), "callback")
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (toInput.length) {
|
||||
if (useConversionSubmenusSetting.value) {
|
||||
options.push({
|
||||
content: "Convert Widget to Input",
|
||||
submenu: {
|
||||
options: toInput
|
||||
}
|
||||
});
|
||||
} else {
|
||||
options.push(...toInput, null);
|
||||
}
|
||||
}
|
||||
if (toWidget.length) {
|
||||
if (useConversionSubmenusSetting.value) {
|
||||
options.push({
|
||||
content: "Convert Input to Widget",
|
||||
submenu: {
|
||||
options: toWidget
|
||||
}
|
||||
});
|
||||
} else {
|
||||
options.push(...toWidget, null);
|
||||
}
|
||||
}
|
||||
}
|
||||
return r;
|
||||
};
|
||||
nodeType.prototype.onGraphConfigured = function() {
|
||||
if (!this.inputs) return;
|
||||
this.widgets ??= [];
|
||||
for (const input of this.inputs) {
|
||||
if (input.widget) {
|
||||
if (!input.widget[GET_CONFIG]) {
|
||||
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
|
||||
}
|
||||
if (input.widget.config) {
|
||||
if (input.widget.config[0] instanceof Array) {
|
||||
input.type = "COMBO";
|
||||
const link = app2.graph.links[input.link];
|
||||
if (link) {
|
||||
link.type = input.type;
|
||||
}
|
||||
}
|
||||
delete input.widget.config;
|
||||
}
|
||||
const w = this.widgets.find((w2) => w2.name === input.widget.name);
|
||||
if (w) {
|
||||
hideWidget(this, w);
|
||||
} else {
|
||||
convertToWidget(this, input);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function() {
|
||||
const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : void 0;
|
||||
if (!app2.configuringGraph && this.widgets) {
|
||||
for (const w of this.widgets) {
|
||||
if (w?.options?.forceInput || w?.options?.defaultInput) {
|
||||
const config = getConfig.call(this, w.name) ?? [
|
||||
w.type,
|
||||
w.options || {}
|
||||
];
|
||||
convertToInput(this, w, config);
|
||||
}
|
||||
}
|
||||
}
|
||||
return r;
|
||||
};
|
||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||
nodeType.prototype.onConfigure = function() {
|
||||
const r = origOnConfigure ? origOnConfigure.apply(this, arguments) : void 0;
|
||||
if (!app2.configuringGraph && this.inputs) {
|
||||
for (const input of this.inputs) {
|
||||
if (input.widget && !input.widget[GET_CONFIG]) {
|
||||
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
|
||||
const w = this.widgets.find((w2) => w2.name === input.widget.name);
|
||||
if (w) {
|
||||
hideWidget(this, w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return r;
|
||||
};
|
||||
function isNodeAtPos(pos) {
|
||||
for (const n of app2.graph.nodes) {
|
||||
if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
__name(isNodeAtPos, "isNodeAtPos");
|
||||
const origOnInputDblClick = nodeType.prototype.onInputDblClick;
|
||||
const ignoreDblClick = Symbol();
|
||||
nodeType.prototype.onInputDblClick = function(slot) {
|
||||
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : void 0;
|
||||
const input = this.inputs[slot];
|
||||
if (!input.widget || !input[ignoreDblClick]) {
|
||||
if (!(input.type in ComfyWidgets) && !(input.widget[GET_CONFIG]?.()?.[0] instanceof Array)) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
const node = LiteGraph.createNode("PrimitiveNode");
|
||||
app2.graph.add(node);
|
||||
const pos = [
|
||||
this.pos[0] - node.size[0] - 30,
|
||||
this.pos[1]
|
||||
];
|
||||
while (isNodeAtPos(pos)) {
|
||||
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
node.pos = pos;
|
||||
node.connect(0, this, slot);
|
||||
node.title = input.name;
|
||||
input[ignoreDblClick] = true;
|
||||
setTimeout(() => {
|
||||
delete input[ignoreDblClick];
|
||||
}, 300);
|
||||
return r;
|
||||
};
|
||||
const onConnectInput = nodeType.prototype.onConnectInput;
|
||||
nodeType.prototype.onConnectInput = function(targetSlot, type, output, originNode, originSlot) {
|
||||
const v = onConnectInput?.(this, arguments);
|
||||
if (type !== "COMBO") return v;
|
||||
if (originNode.outputs[originSlot].widget) return v;
|
||||
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
|
||||
if (!targetCombo || !(targetCombo instanceof Array)) return v;
|
||||
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
|
||||
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
|
||||
return false;
|
||||
}
|
||||
return v;
|
||||
};
|
||||
},
|
||||
registerCustomNodes() {
|
||||
LiteGraph.registerNodeType(
|
||||
"PrimitiveNode",
|
||||
Object.assign(PrimitiveNode, {
|
||||
title: "Primitive"
|
||||
})
|
||||
);
|
||||
PrimitiveNode.category = "utils";
|
||||
}
|
||||
});
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.widgetInputs = window.comfyAPI.widgetInputs || {};
|
||||
window.comfyAPI.widgetInputs.getWidgetConfig = getWidgetConfig;
|
||||
window.comfyAPI.widgetInputs.convertToInput = convertToInput;
|
||||
window.comfyAPI.widgetInputs.setWidgetConfig = setWidgetConfig;
|
||||
window.comfyAPI.widgetInputs.mergeIfValid = mergeIfValid;
|
||||
export {
|
||||
convertToInput,
|
||||
getWidgetConfig,
|
||||
mergeIfValid,
|
||||
setWidgetConfig
|
||||
};
|
||||
//# sourceMappingURL=widgetInputs-DdoWwzg5.js.map
|
||||
1
web/assets/widgetInputs-DdoWwzg5.js.map
generated
vendored
Normal file
1
web/assets/widgetInputs-DdoWwzg5.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
2
web/extensions/core/clipspace.js
vendored
2
web/extensions/core/clipspace.js
vendored
@@ -1,2 +1,2 @@
|
||||
// Shim for extensions\core\clipspace.ts
|
||||
// Shim for extensions/core/clipspace.ts
|
||||
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;
|
||||
|
||||
3
web/extensions/core/colorPalette.js
vendored
Normal file
3
web/extensions/core/colorPalette.js
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
// Shim for extensions/core/colorPalette.ts
|
||||
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
|
||||
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;
|
||||
2
web/extensions/core/groupNode.js
vendored
2
web/extensions/core/groupNode.js
vendored
@@ -1,3 +1,3 @@
|
||||
// Shim for extensions\core\groupNode.ts
|
||||
// Shim for extensions/core/groupNode.ts
|
||||
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
|
||||
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;
|
||||
|
||||
2
web/extensions/core/groupNodeManage.js
vendored
2
web/extensions/core/groupNodeManage.js
vendored
@@ -1,2 +1,2 @@
|
||||
// Shim for extensions\core\groupNodeManage.ts
|
||||
// Shim for extensions/core/groupNodeManage.ts
|
||||
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;
|
||||
|
||||
3
web/extensions/core/widgetInputs.js
vendored
3
web/extensions/core/widgetInputs.js
vendored
@@ -1,4 +1,5 @@
|
||||
// Shim for extensions\core\widgetInputs.ts
|
||||
// Shim for extensions/core/widgetInputs.ts
|
||||
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
|
||||
export const convertToInput = window.comfyAPI.widgetInputs.convertToInput;
|
||||
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
|
||||
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;
|
||||
|
||||
92
web/index.html
vendored
92
web/index.html
vendored
@@ -1,50 +1,42 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>ComfyUI</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
||||
<!-- Browser Test Fonts -->
|
||||
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
|
||||
<style>
|
||||
* {
|
||||
font-family: 'Roboto Mono', 'Noto Color Emoji';
|
||||
}
|
||||
</style> -->
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
|
||||
</head>
|
||||
<body class="litegraph">
|
||||
<div id="vue-app"></div>
|
||||
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
|
||||
<main class="comfy-user-selection-inner">
|
||||
<h1>ComfyUI</h1>
|
||||
<form>
|
||||
<section>
|
||||
<label>New user:
|
||||
<input placeholder="Enter a username" />
|
||||
</label>
|
||||
</section>
|
||||
<div class="comfy-user-existing">
|
||||
<span class="or-separator">OR</span>
|
||||
<section>
|
||||
<label>
|
||||
Existing user:
|
||||
<select>
|
||||
<option hidden disabled selected value> Select a user </option>
|
||||
</select>
|
||||
</label>
|
||||
</section>
|
||||
</div>
|
||||
<footer>
|
||||
<span class="comfy-user-error"> </span>
|
||||
<button class="comfy-btn comfy-user-button-next">Next</button>
|
||||
</footer>
|
||||
</form>
|
||||
</main>
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>ComfyUI</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||
<script type="module" crossorigin src="./assets/index-DGAbdBYF.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-BHJGjcJh.css">
|
||||
</head>
|
||||
<body class="litegraph grid">
|
||||
<div id="vue-app"></div>
|
||||
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
|
||||
<main class="comfy-user-selection-inner">
|
||||
<h1>ComfyUI</h1>
|
||||
<form>
|
||||
<section>
|
||||
<label>New user:
|
||||
<input placeholder="Enter a username" />
|
||||
</label>
|
||||
</section>
|
||||
<div class="comfy-user-existing">
|
||||
<span class="or-separator">OR</span>
|
||||
<section>
|
||||
<label>
|
||||
Existing user:
|
||||
<select>
|
||||
<option hidden disabled selected value> Select a user </option>
|
||||
</select>
|
||||
</label>
|
||||
</section>
|
||||
</div>
|
||||
<footer>
|
||||
<span class="comfy-user-error"> </span>
|
||||
<button class="comfy-btn comfy-user-button-next">Next</button>
|
||||
</footer>
|
||||
</form>
|
||||
</main>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user