Compare commits

...

112 Commits

Author SHA1 Message Date
Dr.Lt.Data
5f9d5a244b Hotfix for the div zero occurrence when memory_used_encode is 0 (#5121)
https://github.com/comfyanonymous/ComfyUI/issues/5069#issuecomment-2382656368
2024-10-09 23:34:34 -04:00
Chenlei Hu
14eba07acd Update web content to release v1.3.11 (#5189)
* Update web content to release v1.3.11

* nit
2024-10-09 22:37:04 -04:00
Jonathan Avila
4b2f0d9413 Increase maximum macOS version to 15.0.1 when forcing upcast attention (#5191) 2024-10-09 22:21:41 -04:00
Yoland Yan
25eac1d780 Change runner label for the new runners (#5197) 2024-10-09 20:08:57 -04:00
comfyanonymous
e38c94228b Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
2024-10-09 19:43:17 -04:00
comfyanonymous
203942c8b2 Fix flux doras with diffusers keys. 2024-10-08 19:03:40 -04:00
Brendan Hoar
3c72c89a52 Update folder_paths.py - try/catch for special file_name values (#5187)
Somehow managed to drop a file called "nul" into a windows checkpoints subdirectory. This caused all sorts of havoc with many nodes that needed the list of checkpoints.
2024-10-08 15:04:32 -04:00
Chenlei Hu
614377abd6 Update web content to release v1.2.64 (#5124) 2024-10-07 17:15:29 -04:00
comfyanonymous
8dfa0cc552 Make SD3 fast previews a little better. 2024-10-07 09:19:59 -04:00
comfyanonymous
e5ecdfdd2d Make fast previews for SDXL a little better by adding a bias. 2024-10-06 19:27:04 -04:00
comfyanonymous
7d29fbf74b Slightly improve the fast previews for flux by adding a bias. 2024-10-06 17:55:46 -04:00
Lex
2c641e64ad IS_CHANGED should be a classmethod (#5159) 2024-10-06 05:47:51 -04:00
comfyanonymous
7d2467e830 Some minor cleanups. 2024-10-05 13:22:39 -04:00
comfyanonymous
6f021d8aa0 Let --verbose have an argument for the log level. 2024-10-04 10:05:34 -04:00
comfyanonymous
d854ed0bcf Allow using SD3 type te output on flux model. 2024-10-03 09:44:54 -04:00
comfyanonymous
abcd006b8c Allow more permutations of clip/t5 in dual clip loader. 2024-10-03 09:26:11 -04:00
comfyanonymous
d985d1d7dc CLIP Loader node now supports clip_l and clip_g only for SD3. 2024-10-02 04:25:17 -04:00
comfyanonymous
d1cdf51e1b Refactor some of the TE detection code. 2024-10-01 07:08:41 -04:00
comfyanonymous
b4626ab93e Add simpletuner lycoris format for SD unet. 2024-09-30 06:03:27 -04:00
comfyanonymous
a9e459c2a4 Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
2024-09-29 11:27:49 -04:00
comfyanonymous
3bb4dec720 Fix issue with loras, lowvram and --fast fp8. 2024-09-28 14:42:32 -04:00
City
8733191563 Flux torch.compile fix (#5082) 2024-09-27 22:07:51 -04:00
comfyanonymous
83b01f960a Add backend option to TorchCompileModel.
If you want to use the cudagraphs backend you need to: --disable-cuda-malloc

If you get other backends working feel free to make a PR to add them.
2024-09-27 02:12:37 -04:00
comfyanonymous
d72e871cfa Add a note that the experimental model downloader api will be removed. 2024-09-26 03:17:52 -04:00
comfyanonymous
037c3159b6 Move some nodes out of _for_testing. 2024-09-25 08:41:22 -04:00
comfyanonymous
bdd4a22a2e Fix flux TE not loading t5 embeddings. 2024-09-24 22:57:22 -04:00
comfyanonymous
fdf37566ef Add batch size to EmptyLatentAudio. 2024-09-24 04:32:55 -04:00
Alex "mcmonkey" Goodwin
08c8968482 Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
2024-09-24 03:50:45 -04:00
chaObserv
479a427a48 Add dpmpp_2m_cfg_pp (#4992) 2024-09-24 02:42:56 -04:00
comfyanonymous
3a0eeee320 Make --listen listen on both ipv4 and ipv6 at the same time by default. 2024-09-23 04:38:19 -04:00
comfyanonymous
447da7ea86 Support listening on multiple addresses. 2024-09-23 04:36:59 -04:00
comfyanonymous
9c41bc8d10 Remove useless line. 2024-09-23 02:32:29 -04:00
Robin Huang
6ad0ddbae4 Run unit tests on Windows/MacOS as well. (#5018)
* Run unit tests on Windows as well.

* Test on mac.

* Continue running on error.

* Compared normalized paths to work cross platform.

* Only test common set of mimetypes across operating systems.
2024-09-22 05:01:39 -04:00
RandomGitUser321
a55142f904 Add ws.close() to the websocket examples (#5020)
* add ws.close() to websocket examples

* add and explain ws.close() in websocket examples
2024-09-22 04:59:10 -04:00
comfyanonymous
5718ef69bb Add total and free ram to /system_stats. 2024-09-22 03:42:11 -04:00
RandomGitUser321
13ecf10a92 Added to the websockets_api_example.py to show how to decode latent previews from the binary stream (#5016)
* Update websockets_api_example.py

* even more simplfied
2024-09-22 02:30:44 -04:00
comfyanonymous
7a415f47a9 Add an optional VAE input to the ControlNetApplyAdvanced node.
Deprecate the other controlnet nodes.
2024-09-22 01:24:52 -04:00
Chenlei Hu
89fa2fca24 Update web content to release v1.2.60 (#5017)
* Update web content to release v1.2.60

* Remove dist.zip
2024-09-21 23:28:54 -04:00
comfyanonymous
364b69e931 Make SD3 empty latent image zeros.
This shouldn't change anything. The reason it was not zeros is because it
did matter in early versions of the code.
2024-09-21 09:13:10 -04:00
comfyanonymous
dc96a1ae19 Load controlnet in fp8 if weights are in fp8. 2024-09-21 04:50:12 -04:00
comfyanonymous
2d810b081e Add load_controlnet_state_dict function. 2024-09-21 01:51:51 -04:00
comfyanonymous
9f7e9f0547 Add an error message when a controlnet needs a VAE but none is given. 2024-09-21 01:33:18 -04:00
comfyanonymous
a355f38ecc Make the SD3 controlnet node the default one. 2024-09-21 01:32:46 -04:00
huchenlei
38c69080c7 Add docstring 2024-09-20 03:16:23 -04:00
comfyanonymous
70a708d726 Fix model merging issue. 2024-09-20 02:31:44 -04:00
yoinked
e7d4782736 add laplace scheduler [2407.03297] (#4990)
* add laplace scheduler [2407.03297]

* should be here instead lol

* better settings
2024-09-19 23:23:09 -04:00
Alex "mcmonkey" Goodwin
3326bdfd4e add internal /folder_paths route (#4980)
returns a json maps of folder paths
2024-09-19 09:52:55 -04:00
Alex "mcmonkey" Goodwin
68bb885d22 add 'is_default' to model paths config (#4979)
* add 'is_default' to model paths config

including impl and doc in example file

* update weirdly overspecific test expectations

* oh there's two

* sigh
2024-09-19 08:59:55 -04:00
comfyanonymous
ad66f7c7d8 Add model_options to load_controlnet function. 2024-09-19 08:23:35 -04:00
Simon Lui
de8e8e3b0d Fix xpu Pytorch nightly build from calling optimize which doesn't exist. (#4978) 2024-09-19 05:11:42 -04:00
Alex "mcmonkey" Goodwin
a1e71cfad1 very simple strong-cache on model list (#4969)
* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
2024-09-19 04:40:14 -04:00
comfyanonymous
0bfc7cc998 Create the temp directory on ComfyUI startup instead. 2024-09-18 09:55:57 -04:00
Tom
7183fd1665 Add route to list model types (#4846)
* Add list models route

* Better readable model types list
2024-09-17 04:22:05 -04:00
Alex "mcmonkey" Goodwin
254838f23c add simple error check to model loading (#4950) 2024-09-17 03:57:17 -04:00
pharmapsychotic
0b7dfa986d Improve tiling calculations to reduce number of tiles that need to be processed. (#4944) 2024-09-17 03:51:10 -04:00
comfyanonymous
d514bb38ee Add some option to model_options for the text encoder.
load_device, offload_device and the initial_device can now be set.
2024-09-17 03:49:54 -04:00
comfyanonymous
0849c80e2a get_key_patches now works without unloading the model. 2024-09-17 01:57:59 -04:00
comfyanonymous
56e8f5e4fd VAEDecodeAudio now does some normalization on the audio. 2024-09-16 00:30:36 -04:00
comfyanonymous
e813abbb2c Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
2024-09-15 07:59:38 -04:00
JettHu
5e68a4ce67 Reduce repeated calls of INPUT_TYPES in cache (#4922) 2024-09-15 01:03:09 -04:00
comfyanonymous
ca08597670 Make the inpaint controlnet node work with non inpaint ones. 2024-09-14 09:17:13 -04:00
comfyanonymous
f48e390032 Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
2024-09-14 09:05:16 -04:00
Chenlei Hu
369a6dd2c4 Remove empty spaces in user_manager.py (#4917) 2024-09-13 23:30:44 -04:00
comfyanonymous
b3ce8fb9fd Revert "Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871)"
This reverts commit f6b7194f64.
2024-09-13 23:24:47 -04:00
comfyanonymous
cf80d28689 Support loading controlnets with different input. 2024-09-13 09:54:37 -04:00
Acly
6fb44c4b7c Make adding links/nodes to ExecutionList non-recursive (#4886)
Graphs with 300+ chained nodes run into maximum recursion depth error (limit is 1000 in CPython)
2024-09-13 08:25:11 -04:00
Chenlei Hu
d2247c1e61 Normalize path returned by /userdata to always use / as separator (#4906) 2024-09-13 03:45:31 -04:00
Chenlei Hu
cb12ad7049 Add full_info flag in /userdata endpoint to list out file size and last modified timestamp (#4905)
* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
2024-09-13 02:40:59 -04:00
JettHu
f6b7194f64 Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871) 2024-09-12 23:02:52 -04:00
comfyanonymous
7c6eb4fb29 Set some nodes as DEPRECATED. 2024-09-12 20:27:07 -04:00
Robin Huang
b962db9952 Add cli arg to override user directory (#4856)
* Override user directory.

* Use overridden user directory.

* Remove prints.

* Remove references to global user_files.

* Remove unused replace_folder function.

* Remove newline.

* Remove global during get_user_directory.

* Add validation.
2024-09-12 08:10:27 -04:00
comfyanonymous
d0b7ab88ba Add a simple experimental TorchCompileModel node.
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
2024-09-12 05:24:25 -04:00
Yoland Yan
405b529545 Minor: update tests-unit README.md (#4896) 2024-09-12 04:53:08 -04:00
comfyanonymous
9d720187f1 types -> comfy_types to fix import issue. 2024-09-12 03:57:46 -04:00
Robin Huang
d247bc5a9c Expand variables in base_path for extra_config_paths.yaml. (#4893)
* Expand variables in base_path for extra_config_paths.yaml.

* Fix comments.
2024-09-12 01:52:06 -04:00
comfyanonymous
9f4daca9d9 Doesn't really make sense for cfg_pp sampler to call regular one. 2024-09-11 02:51:36 -04:00
yoinked
b5d0f2a908 Add CFG++ to DPM++ 2S Ancestral (#3871)
* Update sampling.py

* Update samplers.py

* my bad

* "fix" the sampler

* Update samplers.py

* i named it wrong

* minor sampling improvements

mainly using a dynamic rho value (hey this sounds a lot like smea!!!)

* revert rho change

rho? r? its just 1/2
2024-09-11 02:49:44 -04:00
bymyself
e760bf5c40 Add content-type filter method to folder_paths (#4054)
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
2024-09-11 02:00:07 -04:00
comfyanonymous
36c83cdbba Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people
who use reverse proxies.
2024-09-11 01:06:37 -04:00
Yoland Yan
81778a7feb [🗻 Mount Fuji Commit] Add unit tests for folder path utilities (#4869)
All past 30 min of comtts are done on the top of Mt Fuji
By Comfy, Robin, and Yoland
All other comfy org members died on the way

Introduced unit tests to verify the correctness of various folder path
utility functions such as `get_directory_by_type`, `annotated_filepath`,
and `recursive_search` among others. These tests cover scenarios
including directory retrieval, filepath annotation, recursive file
searches, and filtering files by extensions, enhancing the robustness
and reliability of the codebase.
2024-09-10 00:44:49 -04:00
comfyanonymous
bc94662b31 Cleanup. 2024-09-10 00:43:37 -04:00
Robin Huang
9fa8faa44a Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path.

* Add test.

* Add unit test for expanding base path.

* Simplify unit test.

* Remove comment.

* Remove comment.

* Checkpoints.

* Refactor.
2024-09-10 00:33:44 -04:00
comfyanonymous
9a7444e39f Add diffusion_models to the extra_model_paths.yaml.example 2024-09-10 00:21:33 -04:00
comfyanonymous
54fca4a218 If host does not contain a port only compare the hostnames. 2024-09-09 16:28:23 -04:00
Chenlei Hu
cd4955367e Add back CI action for tests-ui (#4859) 2024-09-09 04:32:55 -04:00
david02871
8354203d95 Add .venv to gitignore (#4756) 2024-09-09 04:31:18 -04:00
comfyanonymous
e0b41243b4 Fix issue where sometimes origin doesn't contain the port. 2024-09-09 03:18:17 -04:00
Alex "mcmonkey" Goodwin
619263d4a6 allow current timestamp in save image prefix (#4030) 2024-09-09 02:55:51 -04:00
comfyanonymous
e3b0402bb7 Ignore origin domain when it's empty. 2024-09-09 01:04:56 -04:00
Darion
967867d48c fix: url decode filename from API (#4801) 2024-09-08 21:02:32 -04:00
comfyanonymous
cbaac71bf5 Fix issue with last commit. 2024-09-08 19:35:23 -04:00
comfyanonymous
3ab3516e46 By default only accept requests where origin header matches the host.
Browsers are dumb and let any website do requests to localhost this should
prevent this without breaking things. CORS prevents the javascript from
reading the response but they can still write it.

At the moment this is only enabled when the --enable-cors-header argument
is not used.
2024-09-08 18:17:29 -04:00
comfyanonymous
9c5fca75f4 Fix lora issue. 2024-09-08 10:10:47 -04:00
guill
a5da4d0b3e Fix error with ExecutionBlocker and OUTPUT_IS_LIST (#4836)
This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
2024-09-08 09:48:47 -04:00
comfyanonymous
32a60a7bac Support onetrainer text encoder Flux lora. 2024-09-08 09:31:41 -04:00
Jim Winkens
bb52934ba4 Fix import issue (#4815) 2024-09-07 05:28:32 -04:00
comfyanonymous
8aabd7c8c0 SaveLora node can now save "full diff" lora format.
This isn't actually a lora format and is saving the full diff of the
weights in a format that can be used in the lora loader nodes.
2024-09-07 03:21:02 -04:00
comfyanonymous
a09b29ca11 Add an option to the SaveLora node to store the bias diff. 2024-09-07 03:03:30 -04:00
comfyanonymous
9bfee68773 LoraSave node now supports generating text encoder loras.
text_encoder_diff should be connected to a CLIPMergeSubtract node.

model_diff and text_encoder_diff are optional inputs so you can create
model only loras, text encoder only loras or a lora that contains both.
2024-09-07 02:30:12 -04:00
comfyanonymous
ea77750759 Support a generic Comfy format for text encoder loras.
This is a format with keys like:
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats you can
convert your text encoder loras to this format instead.

If you want to see an example save a text encoder lora with the SaveLora
node with the commit right after this one.
2024-09-07 02:20:39 -04:00
comfyanonymous
c27ebeb1c2 Fix onnx export not working on flux. 2024-09-06 03:21:52 -04:00
guill
0c7c98a965 Nodes using UNIQUE_ID as input are NOT_IDEMPOTENT (#4793)
As suggested by @ltdrdata, we can automatically consider nodes that take
the UNIQUE_ID hidden input to be NOT_IDEMPOTENT.
2024-09-05 19:33:02 -04:00
comfyanonymous
dc2eb75b85 Update stable release workflow to latest pytorch with cuda 12.4. 2024-09-05 19:21:52 -04:00
Chenlei Hu
fa34efe3bd Update frontend to v1.2.47 (#4798)
* Update web content to release v1.2.47

* Update shortcut list
2024-09-05 18:56:01 -04:00
comfyanonymous
5cbaa9e07c Mistoline flux controlnet support. 2024-09-05 00:05:17 -04:00
comfyanonymous
c7427375ee Prioritize freeing partially offloaded models first. 2024-09-04 19:47:32 -04:00
comfyanonymous
22d1241a50 Add an experimental LoraSave node to extract model loras.
The model_diff input should be connected to the output of a
ModelMergeSubtract node.
2024-09-04 16:38:38 -04:00
Jedrzej Kosinski
f04229b84d Add emb_patch support to UNetModel forward (#4779) 2024-09-04 14:35:15 -04:00
Silver
f067ad15d1 Make live preview size a configurable launch argument (#4649)
* Make live preview size a configurable launch argument

* Remove import from testing phase

* Update cli_args.py
2024-09-03 19:16:38 -04:00
comfyanonymous
483004dd1d Support newer glora format. 2024-09-03 17:02:19 -04:00
comfyanonymous
00a5d08103 Lower fp8 lora memory usage. 2024-09-03 01:25:05 -04:00
comfyanonymous
d043997d30 Flux onetrainer lora. 2024-09-02 08:22:15 -04:00
139 changed files with 111976 additions and 87396 deletions

View File

@@ -23,7 +23,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

View File

@@ -12,7 +12,7 @@ on:
description: 'CUDA version'
required: true
type: string
default: "121"
default: "124"
python_minor:
description: 'Python minor version'
required: true

View File

@@ -32,7 +32,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
@@ -55,7 +55,7 @@ jobs:
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

30
.github/workflows/test-unit.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Unit Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit

1
.gitignore vendored
View File

@@ -12,6 +12,7 @@ extra_model_paths.yaml
.vscode/
.idea/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/

View File

@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |

View File

@@ -1,6 +1,6 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger
@@ -36,6 +36,13 @@ class InternalRoutes:
async def get_logs(request):
return web.json_response(app.logger.get_logs())
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:
self._app = web.Application()

View File

@@ -10,14 +10,14 @@ def get_logs():
return "\n".join([formatter.format(x) for x in logs])
def setup_logger(verbose: bool = False, capacity: int = 300):
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
logger.setLevel(log_level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))

View File

@@ -5,17 +5,17 @@ import uuid
import glob
import shutil
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
from folder_paths import user_directory
import folder_paths
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
@@ -25,14 +25,17 @@ class UserManager():
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +47,7 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
@@ -59,6 +62,10 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +87,7 @@ class UserManager():
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
with open(self.get_users_file(), "w") as f:
json.dump(self.users, f)
return user_id
@@ -112,25 +118,69 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)
return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)
return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path):
return web.Response(status=404)
return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
results = [[x] + x.split(os.sep) for x in results]
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return web.json_response(results)
@@ -138,14 +188,14 @@ class UserManager():
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if check_exists and not os.path.exists(path):
return web.Response(status=404)
return path
@routes.get("/userdata/{file}")
@@ -153,7 +203,7 @@ class UserManager():
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path
return web.FileResponse(path)
@routes.post("/userdata/{file}")
@@ -161,7 +211,7 @@ class UserManager():
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
@@ -170,7 +220,7 @@ class UserManager():
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp)
@@ -181,7 +231,7 @@ class UserManager():
return path
os.remove(path)
return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}")
@@ -189,17 +239,17 @@ class UserManager():
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)

View File

@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
self.in_channels,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -169,6 +171,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
sd.pop(k)
return clip
def load(ckpt_path):

View File

@@ -79,13 +79,21 @@ class ControlBase:
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +108,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@@ -123,6 +131,8 @@ class ControlBase:
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -175,7 +185,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
@@ -189,6 +199,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -213,6 +224,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -220,6 +234,13 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -319,7 +340,7 @@ class ControlLoraOps:
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
@@ -376,19 +397,25 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd):
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@@ -403,24 +430,29 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd):
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
@@ -430,17 +462,17 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd):
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
@@ -449,21 +481,30 @@ def load_controlnet_flux_instantx(sd):
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None
supported_inference_dtypes = None
@@ -518,13 +559,15 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data)
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data)
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data)
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@@ -536,25 +579,36 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logging.error("error could not detect control model type.")
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
@@ -590,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
@@ -653,7 +714,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -1,5 +1,4 @@
import torch
import math
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
@@ -41,9 +40,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
del abs_x
return sign.to(dtype=dtype)
return sign
@@ -57,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
@@ -1101,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
@@ -146,23 +150,24 @@ class Flux(SD3):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):

View File

@@ -14,7 +14,7 @@ except:
rms_norm_torch = None
def rms_norm(x, weight, eps=1e-6):
if rms_norm_torch is not None:
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)

View File

@@ -1,4 +1,5 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch
import math
@@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
from .model import Flux
import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
@@ -36,25 +80,33 @@ class ControlNetFlux(Flux):
self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
else:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
@@ -73,9 +125,6 @@ class ControlNetFlux(Flux):
# running on sequences img
img = self.img_in(img)
if not self.latent_input:
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
@@ -131,9 +180,14 @@ class ControlNetFlux(Flux):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else:
hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

View File

@@ -108,7 +108,7 @@ class Flux(nn.Module):
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -151,8 +151,8 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

View File

@@ -201,9 +201,13 @@ def load_lora(lora, to_load):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -227,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
clip_g_present = True
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
@@ -242,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
t5_index = 1
if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
@@ -281,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
@@ -324,14 +338,15 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
@@ -400,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
@@ -438,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -484,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -521,28 +536,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:

View File

@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch.version.__version__))
logging.info("pytorch version: {}".format(torch_version))
except:
pass
@@ -326,7 +326,7 @@ class LoadedModel:
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
@@ -426,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -897,7 +899,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
upcast = True
except:
pass
@@ -1063,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
@@ -1070,6 +1075,14 @@ def supports_fp8_compute(device=None):
return False
if props.minor < 9:
return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
return False
return True
def soft_empty_cache(force=False):

View File

@@ -28,7 +28,7 @@ import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
from comfy.types import UnetWrapperFunction
from comfy.comfy_types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -88,8 +88,12 @@ class LowVramPatch:
self.key = key
self.patches = patches
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
intermediate_dtype = weight.dtype
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -283,17 +287,21 @@ class ModelPatcher:
return list(p)
def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
bk = self.backup.get(k, None)
if bk is not None:
weight = bk.weight
else:
p[k] = (model_sd[k],)
weight = model_sd[k]
if k in self.patches:
p[k] = [weight] + self.patches[k]
else:
p[k] = (weight,)
return p
def model_state_dict(self, filter_prefix=None):

View File

@@ -260,7 +260,6 @@ def fp8_linear(self, input):
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
@@ -300,10 +299,14 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
if comfy.model_management.supports_fp8_compute(load_device):
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if args.fast and not disable_fast_fp8:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@@ -6,7 +6,7 @@ from comfy import model_management
import math
import logging
import comfy.sampler_helpers
import scipy
import scipy.stats
import numpy
def get_area_and_mult(conds, x_in, timestep_in):
@@ -570,8 +570,8 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler):

View File

@@ -29,7 +29,6 @@ import comfy.text_encoders.long_clipl
import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -70,14 +69,14 @@ class CLIP:
clip = target.clip
tokenizer = target.tokenizer
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params))
@@ -348,7 +347,7 @@ class VAE:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
@@ -406,8 +405,35 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
class TEModel(Enum):
CLIP_L = 1
CLIP_H = 2
CLIP_G = 3
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return TEModel.CLIP_G
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
return None
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
class EmptyClass:
pass
@@ -421,39 +447,42 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
te_model = detect_te_model(clip_data[0])
if te_model == TEModel.CLIP_G:
if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
elif te_model == TEModel.CLIP_H:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
elif te_model == TEModel.T5_XXL:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
if weight.shape[-1] == 4096:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else:
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
@@ -475,10 +504,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@@ -562,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
@@ -647,7 +677,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")

View File

@@ -542,6 +542,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -570,6 +571,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()

View File

@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -40,7 +41,8 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:

View File

@@ -49,6 +49,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
optimizations = {"fp8": False}
@classmethod
def matches(s, unet_config, state_dict=None):

View File

@@ -13,12 +13,13 @@ class T5XXLModel(sd1_clip.SDClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])

View File

@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@@ -20,7 +20,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -42,7 +43,8 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
@@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:

View File

@@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
@@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
for b in range(samples.shape[0]):
s = samples[b:b+1]
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
@@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
for d in range(2, dims + 2):
m = mask.narrow(d, t, 1)
m *= ((1.0/feather) * (t + 1))
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
m *= ((1.0/feather) * (t + 1))
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = out
o_d = out_div
@@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask
o_d += mask
o.add_(ps * mask)
o_d.add_(mask)
if pbar is not None:
pbar.update(1)

View File

@@ -1,11 +1,21 @@
import itertools
from typing import Sequence, Mapping
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
import nodes
from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):

View File

@@ -99,30 +99,44 @@ class TopologicalSort:
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
node_ids = [node_unique_id]
links = []
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
while len(node_ids) > 0:
unique_id = node_ids.pop()
if unique_id in self.pendingNodes:
continue
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)
def is_cached(self, node_id):
return False
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
self.output_cache = output_cache
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if self.output_cache.get(from_node_id) is not None:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
assert self.staged_node_id is None

View File

@@ -16,14 +16,15 @@ class EmptyLatentAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds):
batch_size = 1
def generate(self, seconds, batch_size):
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, )
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
}
class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio"

View File

@@ -1,4 +1,6 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
class SetUnionControlNetType:
@classmethod
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
return (control_net,)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}

View File

@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,

View File

@@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:

View File

@@ -0,0 +1,115 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL",),
"text_encoder_diff": ("CLIP",)},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}

View File

@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling")

View File

@@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch"
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()

View File

@@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:

View File

@@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3:
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
"ControlNetApplySD3": "Apply Controlnet with VAE",
}

View File

@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):

View File

@@ -154,7 +154,7 @@ class TomePatchModel:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, ratio):
self.u = None

View File

@@ -0,0 +1,22 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders"
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
CATEGORY = "advanced/model_merging"
@classmethod
def INPUT_TYPES(s):

View File

@@ -37,6 +37,7 @@ class SaveImageWebsocket:
return {}
@classmethod
def IS_CHANGED(s, images):
return time.time()

View File

@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
value = []
for o in results:
if isinstance(o[i], ExecutionBlocker):
value.append(o[i])
else:
value.extend(o[i])
output.append(value)
else:
output.append([o[i] for o in results])
return output

View File

@@ -25,11 +25,16 @@ a111:
#comfyui:
# base_path: path/to/comfyui/
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# diffusion_models: |
# models/diffusion_models
# models/unet
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/

View File

@@ -2,7 +2,9 @@ from __future__ import annotations
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
global input_directory
return input_directory
def get_user_directory() -> str:
return user_directory
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
return os.path.exists(filepath)
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))

View File

@@ -9,7 +9,7 @@ import folder_paths
import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = 512
MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image)
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):

43
main.py
View File

@@ -9,7 +9,7 @@ from comfy.cli_args import args
from app.logger import setup_logger
setup_logger(verbose=args.verbose)
setup_logger(log_level=args.verbose)
def execute_prestartup_script():
@@ -63,6 +63,7 @@ import threading
import gc
import logging
import utils.extra_config
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -85,7 +86,6 @@ if args.windows_standalone_build:
pass
import comfy.utils
import yaml
import execution
import server
@@ -160,7 +160,10 @@ def prompt_worker(q, server):
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
def hijack_progress(server):
@@ -180,27 +183,6 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__":
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
@@ -222,11 +204,11 @@ if __name__ == "__main__":
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
utils.extra_config.load_extra_path_config(config_path)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
@@ -247,21 +229,30 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
if args.quick_test_for_ci:
exit(0)
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None
if args.auto_launch:
def startup_server(scheme, address, port):
import webbrowser
if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1'
if ':' in address:
address = "[{}]".format(address)
webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server

View File

@@ -1,2 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@@ -1,9 +1,10 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import models_dir
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
@@ -29,7 +31,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
@@ -38,102 +40,112 @@ class DownloadModelStatus():
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
DownloadStatusType.ERROR,
0,
"Invalid model name",
"Invalid model name",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
last_update_time = time.time()
with open(file_path, 'wb') as f:
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate

View File

@@ -511,10 +511,11 @@ class CheckpointLoader:
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
DEPRECATED = True
def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
@@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]
@@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
@@ -624,7 +625,7 @@ class LoraLoader:
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
@@ -703,11 +704,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
@@ -738,7 +739,7 @@ class VAELoader:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
@@ -754,7 +755,7 @@ class ControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,)
@@ -770,7 +771,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,)
@@ -786,6 +787,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
DEPRECATED = True
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength):
@@ -815,7 +817,10 @@ class ControlNetApplyAdvanced:
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
},
"optional": {"vae": ("VAE", ),
}
}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -823,7 +828,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0:
return (positive, negative)
@@ -840,7 +845,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@@ -856,7 +861,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
@@ -867,10 +872,13 @@ class UNETLoader:
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
@@ -895,7 +903,7 @@ class CLIPLoader:
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
@@ -912,8 +920,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
@@ -935,7 +943,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders"
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,)
@@ -965,7 +973,7 @@ class StyleModelLoader:
CATEGORY = "loaders"
def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,)
@@ -1030,7 +1038,7 @@ class GLIGENLoader:
CATEGORY = "loaders"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)
@@ -1916,8 +1924,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
"ControlNetApply": "Apply ControlNet (OLD)",
"ControlNetApplyAdvanced": "Apply ControlNet",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
@@ -2101,6 +2109,8 @@ def init_builtin_extra_nodes():
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
]
import_failed = []

View File

@@ -38,6 +38,9 @@ def get_images(ws, prompt):
if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done
else:
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
# bytesIO = BytesIO(out[8:])
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
continue #previews are binary data
history = get_history(prompt_id)[prompt_id]
@@ -151,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:
# for node_id in images:

View File

@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:
# for node_id in images:

128
server.py
View File

@@ -12,6 +12,8 @@ import json
import glob
import struct
import ssl
import socket
import ipaddress
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -80,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
#in that case the Host and Origin hostnames won't match
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
host_domain = host.lower()
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
return response
return origin_only_middleware
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
@@ -99,6 +163,8 @@ class PromptServer():
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
middlewares.append(create_origin_only_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
@@ -155,6 +221,12 @@ class PromptServer():
def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
return web.json_response(model_types)
@routes.get("/models/{folder}")
async def get_models(request):
@@ -418,12 +490,17 @@ class PromptServer():
async def system_stats(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version,
@@ -480,14 +557,15 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
@@ -601,6 +679,7 @@ class PromptServer():
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
@@ -611,10 +690,11 @@ class PromptServer():
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
@@ -622,7 +702,7 @@ class PromptServer():
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
@@ -739,6 +819,9 @@ class PromptServer():
await self.send(*msg)
async def start(self, address, port, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None)
await runner.setup()
ssl_ctx = None
@@ -749,17 +832,26 @@ class PromptServer():
keyfile=args.tls_keyfile)
scheme = "https"
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
logging.info("Starting server\n")
for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
self.address = address
self.port = port
if not hasattr(self, 'address'):
self.address = address #TODO: remove this
self.port = port
if ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if verbose:
logging.info("Starting server\n")
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None:
call_on_start(scheme, address, port)
call_on_start(scheme, self.address, self.port)
def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler)

View File

@@ -2,7 +2,7 @@
## Install test dependencies
`pip install -r tests-units/requirements.txt`
`pip install -r tests-unit/requirements.txt`
## Run tests
`pytest tests-units/`
`pytest tests-unit/`

View File

@@ -0,0 +1,66 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
def test_get_directory_by_type():
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"

View File

View File

@@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)
def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

@@ -1,10 +1,17 @@
import pytest
import tempfile
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock:
"""
@@ -42,7 +49,7 @@ class ContentMock:
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success():
async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
@@ -53,15 +60,13 @@ async def test_download_model_success():
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
@@ -69,6 +74,7 @@ async def test_download_model_success():
'model.sft',
'http://example.com/model.sft',
'checkpoints',
temp_dir,
mock_progress_callback
)
@@ -83,44 +89,48 @@ async def test_download_model_success():
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)
# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
mock_file_path = os.path.join(temp_dir, 'model.sft')
assert os.path.exists(mock_file_path)
with open(mock_file_path, 'rb') as mock_file:
assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'mock_directory',
mock_progress_callback
)
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the expected behavior
assert isinstance(result, DownloadModelStatus)
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors',
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
)
)
mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors',
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'../bad_path',
'../bad_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory'
assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error'
assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
model_name = "test_model.sft"
model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
model_name = "model.safetensors"
folder_path = os.path.join(tmp_path, "mock_dir")
file_path = create_model_path(model_name, folder_path)
assert file_path == os.path.join(folder_path, "model.safetensors")
assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"test/existing_model.sft",
"existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=0.1
)
full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=0.1
)
assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)
@pytest.mark.asyncio
async def test_track_download_progress_interval():
async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)
full_path = os.path.join(temp_dir, 'model.sft')
# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
with patch('time.time', mock_time):
await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=1.0
)
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),

View File

@@ -0,0 +1,120 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"

View File

@@ -0,0 +1,126 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config
import folder_paths
@pytest.fixture
def mock_yaml_content():
return {
'test_config': {
'base_path': '~/App/',
'checkpoints': 'subfolder1',
}
}
@pytest.fixture
def mock_expanded_home():
return '/home/user'
@pytest.fixture
def yaml_config_with_appdata():
return """
test_config:
base_path: '%APPDATA%/ComfyUI'
checkpoints: 'models/checkpoints'
"""
@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture
def mock_expandvars_appdata():
mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock
@pytest.fixture
def mock_add_model_folder_path():
return Mock()
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_calls = [
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args[0] == expected_call[0]
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
assert actual_call.args[2] == expected_call[2]
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
):
# Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata
# Attach mocks
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
# Mock expanduser to do nothing (since we're not testing it here)
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
expected_calls = [
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
# Verify that expandvars was called
assert mock_expandvars_appdata.called

View File

@@ -496,3 +496,29 @@ class TestExecution:
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
int3 = g.node("StubInt", value=3)
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g)
assert result.did_run(output), "The execution should have run"
images = result.get_images(output)
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"

0
utils/__init__.py Normal file
View File

28
utils/extra_config.py Normal file
View File

@@ -0,0 +1,28 @@
import os
import yaml
import folder_paths
import logging
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = os.path.expandvars(os.path.expanduser(base_path))
is_default = False
if "is_default" in conf:
is_default = conf.pop("is_default")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path, is_default)

1
web/assets/CREDIT.txt generated vendored Normal file
View File

@@ -0,0 +1 @@
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.

792
web/assets/GraphView-BGt8GmeB.css generated vendored Normal file
View File

@@ -0,0 +1,792 @@
.editable-text[data-v-54da6fc9] {
display: inline;
}
.editable-text input[data-v-54da6fc9] {
width: 100%;
box-sizing: border-box;
}
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
z-index: 9999;
padding: 0.25rem;
}
[data-v-fc3f26e3] .editable-text {
width: 100%;
height: 100%;
}
[data-v-fc3f26e3] .editable-text input {
width: 100%;
height: 100%;
/* Override the default font size */
font-size: inherit;
}
.side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
font-weight: bold;
}
.side-bar-button[data-v-caa3ee9c] {
width: var(--sidebar-width);
height: var(--sidebar-width);
border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-right: 4px solid var(--p-button-text-primary-color);
}
:root {
--sidebar-width: 64px;
--sidebar-icon-size: 1.5rem;
}
:root .small-sidebar {
--sidebar-width: 40px;
--sidebar-icon-size: 1rem;
}
.side-tool-bar-container[data-v-4da64512] {
display: flex;
flex-direction: column;
align-items: center;
pointer-events: auto;
width: var(--sidebar-width);
height: 100%;
background-color: var(--comfy-menu-bg);
color: var(--fg-color);
}
.side-tool-bar-end[data-v-4da64512] {
align-self: flex-end;
margin-top: auto;
}
.sidebar-content-container[data-v-4da64512] {
height: 100%;
overflow-y: auto;
}
.p-splitter-gutter {
pointer-events: auto;
}
.gutter-hidden {
display: none !important;
}
.side-bar-panel[data-v-b9df3042] {
background-color: var(--bg-color);
pointer-events: auto;
}
.splitter-overlay[data-v-b9df3042] {
width: 100%;
height: 100%;
position: absolute;
top: 0;
left: 0;
background-color: transparent;
pointer-events: none;
/* Set it the same as the ComfyUI menu */
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
999 should be sufficient to make sure splitter overlays on node's DOM
widgets */
z-index: 999;
border: none;
}
._content[data-v-e7b35fd9] {
display: flex;
flex-direction: column
}
._content[data-v-e7b35fd9] > :not([hidden]) ~ :not([hidden]) {
--tw-space-y-reverse: 0;
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse))
}
._footer[data-v-e7b35fd9] {
display: flex;
flex-direction: column;
align-items: flex-end;
padding-top: 1rem
}
[data-v-37f672ab] .highlight {
background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color);
font-weight: bold;
border-radius: 0.25rem;
padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem;
}
.slot_row[data-v-ff07c900] {
padding: 2px;
}
/* Original N-Sidebar styles */
._sb_dot[data-v-ff07c900] {
width: 8px;
height: 8px;
border-radius: 50%;
background-color: grey;
}
.node_header[data-v-ff07c900] {
line-height: 1;
padding: 8px 13px 7px;
margin-bottom: 5px;
font-size: 15px;
text-wrap: nowrap;
overflow: hidden;
display: flex;
align-items: center;
}
.headdot[data-v-ff07c900] {
width: 10px;
height: 10px;
float: inline-start;
margin-right: 8px;
}
.IMAGE[data-v-ff07c900] {
background-color: #64b5f6;
}
.VAE[data-v-ff07c900] {
background-color: #ff6e6e;
}
.LATENT[data-v-ff07c900] {
background-color: #ff9cf9;
}
.MASK[data-v-ff07c900] {
background-color: #81c784;
}
.CONDITIONING[data-v-ff07c900] {
background-color: #ffa931;
}
.CLIP[data-v-ff07c900] {
background-color: #ffd500;
}
.MODEL[data-v-ff07c900] {
background-color: #b39ddb;
}
.CONTROL_NET[data-v-ff07c900] {
background-color: #a5d6a7;
}
._sb_node_preview[data-v-ff07c900] {
background-color: var(--comfy-menu-bg);
font-family: 'Open Sans', sans-serif;
font-size: small;
color: var(--descrip-text);
border: 1px solid var(--descrip-text);
min-width: 300px;
width: -moz-min-content;
width: min-content;
height: -moz-fit-content;
height: fit-content;
z-index: 9999;
border-radius: 12px;
overflow: hidden;
font-size: 12px;
padding-bottom: 10px;
}
._sb_node_preview ._sb_description[data-v-ff07c900] {
margin: 10px;
padding: 6px;
background: var(--border-color);
border-radius: 5px;
font-style: italic;
font-weight: 500;
font-size: 0.9rem;
word-break: break-word;
}
._sb_table[data-v-ff07c900] {
display: grid;
grid-column-gap: 10px;
/* Spazio tra le colonne */
width: 100%;
/* Imposta la larghezza della tabella al 100% del contenitore */
}
._sb_row[data-v-ff07c900] {
display: grid;
grid-template-columns: 10px 1fr 1fr 1fr 10px;
grid-column-gap: 10px;
align-items: center;
padding-left: 9px;
padding-right: 9px;
}
._sb_row_string[data-v-ff07c900] {
grid-template-columns: 10px 1fr 1fr 10fr 1fr;
}
._sb_col[data-v-ff07c900] {
border: 0px solid #000;
display: flex;
align-items: flex-end;
flex-direction: row-reverse;
flex-wrap: nowrap;
align-content: flex-start;
justify-content: flex-end;
}
._sb_inherit[data-v-ff07c900] {
display: inherit;
}
._long_field[data-v-ff07c900] {
background: var(--bg-color);
border: 2px solid var(--border-color);
margin: 5px 5px 0 5px;
border-radius: 10px;
line-height: 1.7;
text-wrap: nowrap;
}
._sb_arrow[data-v-ff07c900] {
color: var(--fg-color);
}
._sb_preview_badge[data-v-ff07c900] {
text-align: center;
background: var(--comfy-input-bg);
font-weight: bold;
color: var(--error-text);
}
.comfy-vue-node-search-container[data-v-2d409367] {
display: flex;
width: 100%;
min-width: 26rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-2d409367] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-2d409367] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-2d409367] {
z-index: 10;
flex-grow: 1;
}
._filter-button[data-v-2d409367] {
z-index: 10;
}
._dialog[data-v-2d409367] {
min-width: 26rem;
}
.invisible-dialog-root {
width: 60%;
min-width: 24rem;
max-width: 48rem;
border: 0 !important;
background-color: transparent !important;
margin-top: 25vh;
margin-left: 400px;
}
@media all and (max-width: 768px) {
.invisible-dialog-root {
margin-left: 0px;
}
}
.node-search-box-dialog-mask {
align-items: flex-start !important;
}
.node-tooltip[data-v-0a4402f9] {
background: var(--comfy-input-bg);
border-radius: 5px;
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
color: var(--input-text);
font-family: sans-serif;
left: 0;
max-width: 30vw;
padding: 4px 8px;
position: absolute;
top: 0;
transform: translate(5px, calc(-100% - 5px));
white-space: pre-wrap;
z-index: 99999;
}
.p-buttongroup-vertical[data-v-ce8bd6ac] {
display: flex;
flex-direction: column;
border-radius: var(--p-button-border-radius);
overflow: hidden;
border: 1px solid var(--p-panel-border-color);
}
.p-buttongroup-vertical .p-button[data-v-ce8bd6ac] {
margin: 0;
border-radius: 0;
}
.comfy-image-wrap[data-v-9bc23daf] {
display: contents;
}
.comfy-image-blur[data-v-9bc23daf] {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
}
.comfy-image-main[data-v-9bc23daf] {
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
-o-object-position: center;
object-position: center;
z-index: 1;
}
.contain .comfy-image-wrap[data-v-9bc23daf] {
position: relative;
width: 100%;
height: 100%;
}
.contain .comfy-image-main[data-v-9bc23daf] {
-o-object-fit: contain;
object-fit: contain;
-webkit-backdrop-filter: blur(10px);
backdrop-filter: blur(10px);
position: absolute;
}
.broken-image-placeholder[data-v-9bc23daf] {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
width: 100%;
height: 100%;
margin: 2rem;
}
.broken-image-placeholder i[data-v-9bc23daf] {
font-size: 3rem;
margin-bottom: 0.5rem;
}
.result-container[data-v-d9c060ae] {
width: 100%;
height: 100%;
aspect-ratio: 1 / 1;
overflow: hidden;
position: relative;
display: flex;
justify-content: center;
align-items: center;
}
.image-preview-mask[data-v-d9c060ae] {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
display: flex;
align-items: center;
justify-content: center;
opacity: 0;
transition: opacity 0.3s ease;
z-index: 1;
}
.result-container:hover .image-preview-mask[data-v-d9c060ae] {
opacity: 1;
}
.task-result-preview[data-v-d4c8a1fe] {
aspect-ratio: 1 / 1;
overflow: hidden;
display: flex;
justify-content: center;
align-items: center;
width: 100%;
height: 100%;
}
.task-result-preview i[data-v-d4c8a1fe],
.task-result-preview span[data-v-d4c8a1fe] {
font-size: 2rem;
}
.task-item[data-v-d4c8a1fe] {
display: flex;
flex-direction: column;
border-radius: 4px;
overflow: hidden;
position: relative;
}
.task-item-details[data-v-d4c8a1fe] {
position: absolute;
bottom: 0;
padding: 0.6rem;
display: flex;
justify-content: space-between;
align-items: center;
width: 100%;
z-index: 1;
}
.task-node-link[data-v-d4c8a1fe] {
padding: 2px;
}
/* In dark mode, transparent background color for tags is not ideal for tags that
are floating on top of images. */
.tag-wrapper[data-v-d4c8a1fe] {
background-color: var(--p-primary-contrast-color);
border-radius: 6px;
display: inline-flex;
}
.node-name-tag[data-v-d4c8a1fe] {
word-break: break-all;
}
.status-tag-group[data-v-d4c8a1fe] {
display: flex;
flex-direction: column;
}
.progress-preview-img[data-v-d4c8a1fe] {
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
-o-object-position: center;
object-position: center;
}
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
cannot use scoped style here. */
img.galleria-image {
max-width: 100vw;
max-height: 100vh;
-o-object-fit: contain;
object-fit: contain;
}
.p-galleria-close-button {
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
z-index: 1;
}
.comfy-vue-side-bar-container[data-v-1b0a8fe3] {
display: flex;
flex-direction: column;
height: 100%;
overflow: hidden;
}
.comfy-vue-side-bar-header[data-v-1b0a8fe3] {
flex-shrink: 0;
border-left: none;
border-right: none;
border-top: none;
border-radius: 0;
padding: 0.25rem 1rem;
min-height: 2.5rem;
}
.comfy-vue-side-bar-header-span[data-v-1b0a8fe3] {
font-size: small;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3] {
flex-grow: 1;
overflow: auto;
scrollbar-width: thin;
scrollbar-color: transparent transparent;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar {
width: 1px;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar-thumb {
background-color: transparent;
}
.scroll-container[data-v-08fa89b1] {
height: 100%;
overflow-y: auto;
}
.queue-grid[data-v-08fa89b1] {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
padding: 0.5rem;
gap: 0.5rem;
}
.tree-node[data-v-633e27ab] {
width: 100%;
display: flex;
align-items: center;
justify-content: space-between;
}
.leaf-count-badge[data-v-633e27ab] {
margin-left: 0.5rem;
}
.node-content[data-v-633e27ab] {
display: flex;
align-items: center;
flex-grow: 1;
}
.leaf-label[data-v-633e27ab] {
margin-left: 0.5rem;
}
[data-v-633e27ab] .editable-text span {
word-break: break-all;
}
[data-v-bd7bae90] .tree-explorer-node-label {
width: 100%;
display: flex;
align-items: center;
margin-left: var(--p-tree-node-gap);
flex-grow: 1;
}
/*
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
* we can create a visual indicator for the drop target without affecting the layout of other elements.
*/
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder) {
position: relative;
}
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder.can-drop)::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
border: 1px solid var(--p-content-color);
pointer-events: none;
}
.node-lib-node-container[data-v-90dfee08] {
height: 100%;
width: 100%
}
.p-selectbutton .p-button[data-v-91077f2a] {
padding: 0.5rem;
}
.p-selectbutton .p-button .pi[data-v-91077f2a] {
font-size: 1.5rem;
}
.field[data-v-91077f2a] {
display: flex;
flex-direction: column;
gap: 0.5rem;
}
.color-picker-container[data-v-91077f2a] {
display: flex;
align-items: center;
gap: 0.5rem;
}
.node-lib-filter-popup {
margin-left: -13px;
}
[data-v-f6a7371a] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-f6a7371a] .node-lib-bookmark-tree-explorer {
padding-bottom: 2px;
}
[data-v-f6a7371a] .p-divider {
margin: var(--comfy-tree-explorer-item-padding) 0px;
}
.model_preview[data-v-32e6c4d9] {
background-color: var(--comfy-menu-bg);
font-family: 'Open Sans', sans-serif;
color: var(--descrip-text);
border: 1px solid var(--descrip-text);
min-width: 300px;
max-width: 500px;
width: -moz-fit-content;
width: fit-content;
height: -moz-fit-content;
height: fit-content;
z-index: 9999;
border-radius: 12px;
overflow: hidden;
font-size: 12px;
padding: 10px;
}
.model_preview_image[data-v-32e6c4d9] {
margin: auto;
width: -moz-fit-content;
width: fit-content;
}
.model_preview_image img[data-v-32e6c4d9] {
max-width: 100%;
max-height: 150px;
-o-object-fit: contain;
object-fit: contain;
}
.model_preview_title[data-v-32e6c4d9] {
font-weight: bold;
text-align: center;
font-size: 14px;
}
.model_preview_top_container[data-v-32e6c4d9] {
text-align: center;
line-height: 0.5;
}
.model_preview_filename[data-v-32e6c4d9],
.model_preview_author[data-v-32e6c4d9],
.model_preview_architecture[data-v-32e6c4d9] {
display: inline-block;
text-align: center;
margin: 5px;
font-size: 10px;
}
.model_preview_prefix[data-v-32e6c4d9] {
font-weight: bold;
}
.model-lib-model-icon-container[data-v-70b69131] {
display: inline-block;
position: relative;
left: 0;
height: 1.5rem;
vertical-align: top;
width: 0px;
}
.model-lib-model-icon[data-v-70b69131] {
background-size: cover;
background-position: center;
display: inline-block;
position: relative;
left: -2.5rem;
height: 2rem;
width: 2rem;
vertical-align: top;
}
.pi-fake-spacer {
height: 1px;
width: 16px;
}
[data-v-74b01bce] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-d2d58252] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-84e785b8] .p-togglebutton::before {
display: none
}
[data-v-84e785b8] .p-togglebutton {
position: relative;
flex-shrink: 0;
border-radius: 0px;
background-color: transparent;
padding-left: 0.5rem;
padding-right: 0.5rem
}
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
border-bottom-width: 2px;
border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
visibility: visible
}
.status-indicator[data-v-84e785b8] {
position: absolute;
font-weight: 700;
font-size: 1.5rem;
top: 50%;
left: 50%;
transform: translate(-50%, -50%)
}
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
display: none
}
[data-v-84e785b8] .p-togglebutton .close-button {
visibility: hidden
}
.top-menubar[data-v-2ec1b620] .p-menubar-item-link svg {
display: none;
}
[data-v-2ec1b620] .p-menubar-submenu.dropdown-direction-up {
top: auto;
bottom: 100%;
flex-direction: column-reverse;
}
.keybinding-tag[data-v-2ec1b620] {
background: var(--p-content-hover-background);
border-color: var(--p-content-border-color);
border-style: solid;
}
[data-v-713442be] .p-inputtext {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
}
.comfyui-queue-button[data-v-fcd3efcd] .p-splitbutton-dropdown {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.actionbar[data-v-bc6c78dd] {
pointer-events: all;
position: fixed;
z-index: 1000;
}
.actionbar.is-docked[data-v-bc6c78dd] {
position: static;
border-style: none;
background-color: transparent;
padding: 0px;
}
.actionbar.is-dragging[data-v-bc6c78dd] {
-webkit-user-select: none;
-moz-user-select: none;
user-select: none;
}
[data-v-bc6c78dd] .p-panel-content {
padding: 0.25rem;
}
[data-v-bc6c78dd] .p-panel-header {
display: none;
}
.comfyui-menu[data-v-b13fdc92] {
width: 100vw;
background: var(--comfy-menu-bg);
color: var(--fg-color);
font-family: Arial, Helvetica, sans-serif;
font-size: 0.8em;
box-sizing: border-box;
z-index: 1000;
order: 0;
grid-column: 1/-1;
max-height: 90vh;
}
.comfyui-menu.dropzone[data-v-b13fdc92] {
background: var(--p-highlight-background);
}
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
background: var(--p-highlight-background-focus);
}
.comfyui-logo[data-v-b13fdc92] {
font-size: 1.2em;
-webkit-user-select: none;
-moz-user-select: none;
user-select: none;
cursor: default;
}

17465
web/assets/GraphView-CVV2XJjS.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/GraphView-CVV2XJjS.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

865
web/assets/colorPalette-D5oi2-2V.js generated vendored Normal file
View File

@@ -0,0 +1,865 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { k as app, aP as LGraphCanvas, bO as useToastStore, ca as $el, z as LiteGraph } from "./index-DGAbdBYF.js";
const colorPalettes = {
dark: {
id: "dark",
name: "Dark (Default)",
colors: {
node_slot: {
CLIP: "#FFD500",
// bright yellow
CLIP_VISION: "#A8DADC",
// light blue-gray
CLIP_VISION_OUTPUT: "#ad7452",
// rusty brown-orange
CONDITIONING: "#FFA931",
// vibrant orange-yellow
CONTROL_NET: "#6EE7B7",
// soft mint green
IMAGE: "#64B5F6",
// bright sky blue
LATENT: "#FF9CF9",
// light pink-purple
MASK: "#81C784",
// muted green
MODEL: "#B39DDB",
// light lavender-purple
STYLE_MODEL: "#C2FFAE",
// light green-yellow
VAE: "#FF6E6E",
// bright red
NOISE: "#B0B0B0",
// gray
GUIDER: "#66FFFF",
// cyan
SAMPLER: "#ECB4B4",
// very soft red
SIGMAS: "#CDFFCD",
// soft lime green
TAESD: "#DCC274"
// cheesecake
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#222",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#FFF",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#AAA",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#333",
NODE_DEFAULT_BGCOLOR: "#353535",
NODE_DEFAULT_BOXCOLOR: "#666",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#FFF",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#222",
WIDGET_OUTLINE_COLOR: "#666",
WIDGET_TEXT_COLOR: "#DDD",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA",
BADGE_FG_COLOR: "#FFF",
BADGE_BG_COLOR: "#0F1F0F"
},
comfy_base: {
"fg-color": "#fff",
"bg-color": "#202020",
"comfy-menu-bg": "#353535",
"comfy-input-bg": "#222",
"input-text": "#ddd",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#4e4e4e",
"tr-even-bg-color": "#222",
"tr-odd-bg-color": "#353535",
"content-bg": "#4e4e4e",
"content-fg": "#fff",
"content-hover-bg": "#222",
"content-hover-fg": "#fff"
}
}
},
light: {
id: "light",
name: "Light",
colors: {
node_slot: {
CLIP: "#FFA726",
// orange
CLIP_VISION: "#5C6BC0",
// indigo
CLIP_VISION_OUTPUT: "#8D6E63",
// brown
CONDITIONING: "#EF5350",
// red
CONTROL_NET: "#66BB6A",
// green
IMAGE: "#42A5F5",
// blue
LATENT: "#AB47BC",
// purple
MASK: "#9CCC65",
// light green
MODEL: "#7E57C2",
// deep purple
STYLE_MODEL: "#D4E157",
// lime
VAE: "#FF7043"
// deep orange
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "lightgray",
NODE_TITLE_COLOR: "#222",
NODE_SELECTED_TITLE_COLOR: "#000",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#444",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#F7F7F7",
NODE_DEFAULT_BGCOLOR: "#F5F5F5",
NODE_DEFAULT_BOXCOLOR: "#CCC",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#000",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.1)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#D4D4D4",
WIDGET_OUTLINE_COLOR: "#999",
WIDGET_TEXT_COLOR: "#222",
WIDGET_SECONDARY_TEXT_COLOR: "#555",
LINK_COLOR: "#4CAF50",
EVENT_LINK_COLOR: "#FF9800",
CONNECTING_LINK_COLOR: "#2196F3",
BADGE_FG_COLOR: "#000",
BADGE_BG_COLOR: "#FFF"
},
comfy_base: {
"fg-color": "#222",
"bg-color": "#DDD",
"comfy-menu-bg": "#F5F5F5",
"comfy-input-bg": "#C9C9C9",
"input-text": "#222",
"descrip-text": "#444",
"drag-text": "#555",
"error-text": "#F44336",
"border-color": "#888",
"tr-even-bg-color": "#f9f9f9",
"tr-odd-bg-color": "#fff",
"content-bg": "#e0e0e0",
"content-fg": "#222",
"content-hover-bg": "#adadad",
"content-hover-fg": "#222"
}
}
},
solarized: {
id: "solarized",
name: "Solarized",
colors: {
node_slot: {
CLIP: "#2AB7CA",
// light blue
CLIP_VISION: "#6c71c4",
// blue violet
CLIP_VISION_OUTPUT: "#859900",
// olive green
CONDITIONING: "#d33682",
// magenta
CONTROL_NET: "#d1ffd7",
// light mint green
IMAGE: "#5940bb",
// deep blue violet
LATENT: "#268bd2",
// blue
MASK: "#CCC9E7",
// light purple-gray
MODEL: "#dc322f",
// red
STYLE_MODEL: "#1a998a",
// teal
UPSCALE_MODEL: "#054A29",
// dark green
VAE: "#facfad"
// light pink-orange
},
litegraph_base: {
NODE_TITLE_COLOR: "#fdf6e3",
// Base3
NODE_SELECTED_TITLE_COLOR: "#A9D400",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#657b83",
// Base00
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#094656",
NODE_DEFAULT_BGCOLOR: "#073642",
// Base02
NODE_DEFAULT_BOXCOLOR: "#839496",
// Base0
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#fdf6e3",
// Base3
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#002b36",
// Base03
WIDGET_OUTLINE_COLOR: "#839496",
// Base0
WIDGET_TEXT_COLOR: "#fdf6e3",
// Base3
WIDGET_SECONDARY_TEXT_COLOR: "#93a1a1",
// Base1
LINK_COLOR: "#2aa198",
// Solarized Cyan
EVENT_LINK_COLOR: "#268bd2",
// Solarized Blue
CONNECTING_LINK_COLOR: "#859900"
// Solarized Green
},
comfy_base: {
"fg-color": "#fdf6e3",
// Base3
"bg-color": "#002b36",
// Base03
"comfy-menu-bg": "#073642",
// Base02
"comfy-input-bg": "#002b36",
// Base03
"input-text": "#93a1a1",
// Base1
"descrip-text": "#586e75",
// Base01
"drag-text": "#839496",
// Base0
"error-text": "#dc322f",
// Solarized Red
"border-color": "#657b83",
// Base00
"tr-even-bg-color": "#002b36",
"tr-odd-bg-color": "#073642",
"content-bg": "#657b83",
"content-fg": "#fdf6e3",
"content-hover-bg": "#002b36",
"content-hover-fg": "#fdf6e3"
}
}
},
arc: {
id: "arc",
name: "Arc",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#2b2f38",
NODE_TITLE_COLOR: "#b2b7bd",
NODE_SELECTED_TITLE_COLOR: "#FFF",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#AAA",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#2b2f38",
NODE_DEFAULT_BGCOLOR: "#242730",
NODE_DEFAULT_BOXCOLOR: "#6e7581",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#FFF",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 22,
WIDGET_BGCOLOR: "#2b2f38",
WIDGET_OUTLINE_COLOR: "#6e7581",
WIDGET_TEXT_COLOR: "#DDD",
WIDGET_SECONDARY_TEXT_COLOR: "#b2b7bd",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730",
"content-bg": "#6e7581",
"content-fg": "#fff",
"content-hover-bg": "#2b2f38",
"content-hover-fg": "#fff"
}
}
},
nord: {
id: "nord",
name: "Nord",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#212732",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#bcc2c8",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#2e3440",
NODE_DEFAULT_BGCOLOR: "#161b22",
NODE_DEFAULT_BOXCOLOR: "#545d70",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#2e3440",
WIDGET_OUTLINE_COLOR: "#545d70",
WIDGET_TEXT_COLOR: "#bcc2c8",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22",
"content-bg": "#545d70",
"content-fg": "#e5eaf0",
"content-hover-bg": "#2e3440",
"content-hover-fg": "#e5eaf0"
}
}
},
github: {
id: "github",
name: "Github",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#040506",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#bcc2c8",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#161b22",
NODE_DEFAULT_BGCOLOR: "#13171d",
NODE_DEFAULT_BOXCOLOR: "#30363d",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#161b22",
WIDGET_OUTLINE_COLOR: "#30363d",
WIDGET_TEXT_COLOR: "#bcc2c8",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d",
"content-bg": "#30363d",
"content-fg": "#e5eaf0",
"content-hover-bg": "#161b22",
"content-hover-fg": "#e5eaf0"
}
}
}
};
const id = "Comfy.ColorPalette";
const idCustomColorPalettes = "Comfy.CustomColorPalettes";
const defaultColorPaletteId = "dark";
const els = {
select: null
};
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
}, "getCustomColorPalettes");
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
return app.ui.settings.setSettingValue(
idCustomColorPalettes,
customColorPalettes
);
}, "setCustomColorPalettes");
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
if (!colorPaletteId) {
colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[colorPaletteId]) {
return customColorPalettes[colorPaletteId];
}
}
return colorPalettes[colorPaletteId];
}, "getColorPalette");
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
app.ui.settings.setSettingValue(id, colorPaletteId);
}, "setColorPalette");
app.registerExtension({
name: id,
init() {
LGraphCanvas.prototype.updateBackground = function(image, clearBackgroundColor) {
this._bg_img = new Image();
this._bg_img.name = image;
this._bg_img.src = image;
this._bg_img.onload = () => {
this.draw(true, true);
};
this.background_image = image;
this.clear_background = true;
this.clear_background_color = clearBackgroundColor;
this._pattern = null;
};
},
addCustomNodeDefs(node_defs) {
const sortObjectKeys = /* @__PURE__ */ __name((unordered) => {
return Object.keys(unordered).sort().reduce((obj, key) => {
obj[key] = unordered[key];
return obj;
}, {});
}, "sortObjectKeys");
function getSlotTypes() {
var types = [];
const defs = node_defs;
for (const nodeId in defs) {
const nodeData = defs[nodeId];
var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] !== void 0) {
inputs = Object.assign(
{},
nodeData["input"]["required"],
nodeData["input"]["optional"]
);
}
for (const inputName in inputs) {
const inputData = inputs[inputName];
const type = inputData[0];
if (!Array.isArray(type)) {
types.push(type);
}
}
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
types.push(output);
}
}
return types;
}
__name(getSlotTypes, "getSlotTypes");
function completeColorPalette(colorPalette) {
var types = getSlotTypes();
for (const type of types) {
if (!colorPalette.colors.node_slot[type]) {
colorPalette.colors.node_slot[type] = "";
}
}
colorPalette.colors.node_slot = sortObjectKeys(
colorPalette.colors.node_slot
);
return colorPalette;
}
__name(completeColorPalette, "completeColorPalette");
const getColorPaletteTemplate = /* @__PURE__ */ __name(async () => {
let colorPalette = {
id: "my_color_palette_unique_id",
name: "My Color Palette",
colors: {
node_slot: {},
litegraph_base: {},
comfy_base: {}
}
};
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette2.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette2.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette);
}, "getColorPaletteTemplate");
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
if (typeof colorPalette !== "object") {
useToastStore().addAlert("Invalid color palette.");
return;
}
if (!colorPalette.id) {
useToastStore().addAlert("Color palette missing id.");
return;
}
if (!colorPalette.name) {
useToastStore().addAlert("Color palette missing name.");
return;
}
if (!colorPalette.colors) {
useToastStore().addAlert("Color palette missing colors.");
return;
}
if (colorPalette.colors.node_slot && typeof colorPalette.colors.node_slot !== "object") {
useToastStore().addAlert("Invalid color palette colors.node_slot.");
return;
}
const customColorPalettes = getCustomColorPalettes();
customColorPalettes[colorPalette.id] = colorPalette;
setCustomColorPalettes(customColorPalettes);
for (const option of els.select.childNodes) {
if (option.value === "custom_" + colorPalette.id) {
els.select.removeChild(option);
}
}
els.select.append(
$el("option", {
textContent: colorPalette.name + " (custom)",
value: "custom_" + colorPalette.id,
selected: true
})
);
setColorPalette("custom_" + colorPalette.id);
await loadColorPalette(colorPalette);
}, "addCustomColorPalette");
const deleteCustomColorPalette = /* @__PURE__ */ __name(async (colorPaletteId) => {
const customColorPalettes = getCustomColorPalettes();
delete customColorPalettes[colorPaletteId];
setCustomColorPalettes(customColorPalettes);
for (const opt of els.select.childNodes) {
const option = opt;
if (option.value === defaultColorPaletteId) {
option.selected = true;
}
if (option.value === "custom_" + colorPaletteId) {
els.select.removeChild(option);
}
}
setColorPalette(defaultColorPaletteId);
await loadColorPalette(getColorPalette());
}, "deleteCustomColorPalette");
const loadColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
colorPalette = await completeColorPalette(colorPalette);
if (colorPalette.colors) {
if (colorPalette.colors.node_slot) {
Object.assign(
app.canvas.default_connection_color_byType,
colorPalette.colors.node_slot
);
Object.assign(
LGraphCanvas.link_type_colors,
colorPalette.colors.node_slot
);
}
if (colorPalette.colors.litegraph_base) {
app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR;
app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR;
for (const key in colorPalette.colors.litegraph_base) {
if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) {
LiteGraph[key] = colorPalette.colors.litegraph_base[key];
}
}
}
if (colorPalette.colors.comfy_base) {
const rootStyle = document.documentElement.style;
for (const key in colorPalette.colors.comfy_base) {
rootStyle.setProperty(
"--" + key,
colorPalette.colors.comfy_base[key]
);
}
}
if (colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR) {
app.bypassBgColor = colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR;
}
app.canvas.draw(true, true);
}
}, "loadColorPalette");
const fileInput = $el("input", {
type: "file",
accept: ".json",
style: { display: "none" },
parent: document.body,
onchange: /* @__PURE__ */ __name(() => {
const file = fileInput.files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
await addCustomColorPalette(JSON.parse(reader.result));
};
reader.readAsText(file);
}
}, "onchange")
});
app.ui.settings.addSetting({
id,
category: ["Comfy", "ColorPalette"],
name: "Color Palette",
type: /* @__PURE__ */ __name((name, setter, value) => {
const options = [
...Object.values(colorPalettes).map(
(c) => $el("option", {
textContent: c.name,
value: c.id,
selected: c.id === value
})
),
...Object.values(getCustomColorPalettes()).map(
(c) => $el("option", {
textContent: `${c.name} (custom)`,
value: `custom_${c.id}`,
selected: `custom_${c.id}` === value
})
)
];
els.select = $el(
"select",
{
style: {
marginBottom: "0.15rem",
width: "100%"
},
onchange: /* @__PURE__ */ __name((e) => {
setter(e.target.value);
}, "onchange")
},
options
);
return $el("tr", [
$el("td", [
els.select,
$el(
"div",
{
style: {
display: "grid",
gap: "4px",
gridAutoFlow: "column"
}
},
[
$el("input", {
type: "button",
value: "Export",
onclick: /* @__PURE__ */ __name(async () => {
const colorPaletteId = app.ui.settings.getSettingValue(
id,
defaultColorPaletteId
);
const colorPalette = await completeColorPalette(
getColorPalette(colorPaletteId)
);
const json = JSON.stringify(colorPalette, null, 2);
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: colorPaletteId + ".json",
style: { display: "none" },
parent: document.body
});
a.click();
setTimeout(function() {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}, "onclick")
}),
$el("input", {
type: "button",
value: "Import",
onclick: /* @__PURE__ */ __name(() => {
fileInput.click();
}, "onclick")
}),
$el("input", {
type: "button",
value: "Template",
onclick: /* @__PURE__ */ __name(async () => {
const colorPalette = await getColorPaletteTemplate();
const json = JSON.stringify(colorPalette, null, 2);
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "color_palette.json",
style: { display: "none" },
parent: document.body
});
a.click();
setTimeout(function() {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}, "onclick")
}),
$el("input", {
type: "button",
value: "Delete",
onclick: /* @__PURE__ */ __name(async () => {
let colorPaletteId = app.ui.settings.getSettingValue(
id,
defaultColorPaletteId
);
if (colorPalettes[colorPaletteId]) {
useToastStore().addAlert(
"You cannot delete a built-in color palette."
);
return;
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
}
await deleteCustomColorPalette(colorPaletteId);
}, "onclick")
})
]
)
])
]);
}, "type"),
defaultValue: defaultColorPaletteId,
async onChange(value) {
if (!value) {
return;
}
let palette = colorPalettes[value];
if (palette) {
await loadColorPalette(palette);
} else if (value.startsWith("custom_")) {
value = value.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[value]) {
palette = customColorPalettes[value];
await loadColorPalette(customColorPalettes[value]);
}
}
let { BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR } = palette.colors.litegraph_base;
if (BACKGROUND_IMAGE === void 0 || CLEAR_BACKGROUND_COLOR === void 0) {
const base = colorPalettes["dark"].colors.litegraph_base;
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
}
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
}
});
}
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
export {
defaultColorPalette as d,
getColorPalette as g
};
//# sourceMappingURL=colorPalette-D5oi2-2V.js.map

1
web/assets/colorPalette-D5oi2-2V.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/index-BD-Ia1C4.js.map generated vendored

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1
web/assets/index-BMC1ey-i.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/index-CI3N807S.js.map generated vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
web/assets/index-DGAbdBYF.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

2602
web/assets/sorted-custom-node-map.json generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,37 @@
.lds-ring {
display: inline-block;
position: relative;
width: 1em;
height: 1em;
}
.lds-ring div {
box-sizing: border-box;
display: block;
position: absolute;
width: 100%;
height: 100%;
border: 0.15em solid #fff;
border-radius: 50%;
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
border-color: #fff transparent transparent transparent;
}
.lds-ring div:nth-child(1) {
animation-delay: -0.45s;
}
.lds-ring div:nth-child(2) {
animation-delay: -0.3s;
}
.lds-ring div:nth-child(3) {
animation-delay: -0.15s;
}
@keyframes lds-ring {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
.comfy-user-selection {
width: 100vw;
height: 100vh;

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,15 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
import { b4 as api, ca as $el } from "./index-DGAbdBYF.js";
function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
return div.firstElementChild;
}
__name(createSpinner, "createSpinner");
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.spinner = window.comfyAPI.spinner || {};
window.comfyAPI.spinner.createSpinner = createSpinner;
class UserSelectionScreen {
static {
__name(this, "UserSelectionScreen");
@@ -117,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-CyXKCVy3.js.map
//# sourceMappingURL=userSelection-Duxc-t_S.js.map

1
web/assets/userSelection-Duxc-t_S.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

756
web/assets/widgetInputs-DdoWwzg5.js generated vendored Normal file
View File

@@ -0,0 +1,756 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { l as LGraphNode, k as app, cf as applyTextReplacements, ce as ComfyWidgets, ci as addValueControlWidgets, z as LiteGraph } from "./index-DGAbdBYF.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = [
"STRING",
"combo",
"number",
"toggle",
"BOOLEAN",
"text",
"string"
];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
const TARGET = Symbol();
const replacePropertyName = "Run widget replace on values";
class PrimitiveNode extends LGraphNode {
static {
__name(this, "PrimitiveNode");
}
controlValues;
lastType;
static category;
constructor(title) {
super(title);
this.addOutput("connect to widget input", "*");
this.serialize_widgets = true;
this.isVirtualNode = true;
if (!this.properties || !(replacePropertyName in this.properties)) {
this.addProperty(replacePropertyName, false, "boolean");
}
}
applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
let links2 = [];
for (const l of node.outputs[0].links) {
const linkInfo = app.graph.links[l];
const n = node.graph.getNodeById(linkInfo.target_id);
if (n.type == "Reroute") {
links2 = links2.concat(get_links(n));
} else {
links2.push(l);
}
}
return links2;
}
__name(get_links, "get_links");
let links = [
...get_links(this).map((l) => app.graph.links[l]),
...extraLinks
];
let v = this.widgets?.[0].value;
if (v && this.properties[replacePropertyName]) {
v = applyTextReplacements(app, v);
}
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
let widget;
if (input.widget[TARGET]) {
widget = input.widget[TARGET];
} else {
const widgetName = input.widget.name;
if (widgetName) {
widget = node.widgets.find((w) => w.name === widgetName);
}
}
if (widget) {
widget.value = v;
if (widget.callback) {
widget.callback(
widget.value,
app.canvas,
node,
app.canvas.graph_mouse,
{}
);
}
}
}
}
refreshComboInNode() {
const widget = this.widgets?.[0];
if (widget?.type === "combo") {
widget.options.values = this.outputs[0].widget[GET_CONFIG]()[0];
if (!widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}
onAfterGraphConfigured() {
if (this.outputs[0].links?.length && !this.widgets?.length) {
if (!this.#onFirstConnection()) return;
if (this.widgets) {
for (let i = 0; i < this.widgets_values.length; i++) {
const w = this.widgets[i];
if (w) {
w.value = this.widgets_values[i];
}
}
}
this.#mergeWidgetConfig();
}
}
onConnectionsChange(_, index, connected) {
if (app.configuringGraph) {
return;
}
const links = this.outputs[0].links;
if (connected) {
if (links?.length && !this.widgets?.length) {
this.#onFirstConnection();
}
} else {
this.#mergeWidgetConfig();
if (!links?.length) {
this.onLastDisconnect();
}
}
}
onConnectOutput(slot, type, input, target_node, target_slot) {
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false;
}
if (this.outputs[slot].links?.length) {
const valid = this.#isValidConnection(input);
if (valid) {
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
}
return valid;
}
}
#onFirstConnection(recreating) {
if (!this.outputs[0].links) {
this.onLastDisconnect();
return;
}
const linkId = this.outputs[0].links[0];
const link = this.graph.links[linkId];
if (!link) return;
const theirNode = this.graph.getNodeById(link.target_id);
if (!theirNode || !theirNode.inputs) return;
const input = theirNode.inputs[link.target_slot];
if (!input) return;
let widget;
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return;
widget = { name: input.name, [GET_CONFIG]: () => [input.type, {}] };
} else {
widget = input.widget;
}
const config = widget[GET_CONFIG]?.();
if (!config) return;
const { type } = getWidgetType(config);
this.outputs[0].type = type;
this.outputs[0].name = type;
this.outputs[0].widget = widget;
this.#createWidget(
widget[CONFIG] ?? config,
theirNode,
widget.name,
recreating,
widget[TARGET]
);
}
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
let type = inputData[0];
if (type instanceof Array) {
type = "COMBO";
}
const size = this.size;
let widget;
if (type in ComfyWidgets) {
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
} else {
widget = this.addWidget(type, "value", null, () => {
}, {});
}
if (targetWidget) {
widget.value = targetWidget.value;
} else if (node?.widgets && widget) {
const theirWidget = node.widgets.find((w) => w.name === widgetName);
if (theirWidget) {
widget.value = theirWidget.value;
}
}
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidgets(
this,
widget,
control_value,
void 0,
inputData
);
let filter = this.widgets_values?.[2];
if (filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
const controlValues = this.controlValues;
if (this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
for (let i = 0; i < controlValues.length; i++) {
this.widgets[i + 1].value = controlValues[i];
}
}
const callback = widget.callback;
const self = this;
widget.callback = function() {
const r = callback ? callback.apply(this, arguments) : void 0;
self.applyToGraph();
return r;
};
this.size = [
Math.max(this.size[0], size[0]),
Math.max(this.size[1], size[1])
];
if (!recreating) {
const sz = this.computeSize();
if (this.size[0] < sz[0]) {
this.size[0] = sz[0];
}
if (this.size[1] < sz[1]) {
this.size[1] = sz[1];
}
requestAnimationFrame(() => {
if (this.onResize) {
this.onResize(this.size);
}
});
}
}
recreateWidget() {
const values = this.widgets?.map((w) => w.value);
this.#removeWidgets();
this.#onFirstConnection(true);
if (values?.length) {
for (let i = 0; i < this.widgets?.length; i++)
this.widgets[i].value = values[i];
}
return this.widgets?.[0];
}
#mergeWidgetConfig() {
const output = this.outputs[0];
const links = output.links;
const hasConfig = !!output.widget[CONFIG];
if (hasConfig) {
delete output.widget[CONFIG];
}
if (links?.length < 2 && hasConfig) {
if (links.length) {
this.recreateWidget();
}
return;
}
const config1 = output.widget[GET_CONFIG]();
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
if (!isNumber) return;
for (const linkId of links) {
const link = app.graph.links[linkId];
if (!link) continue;
const theirNode = app.graph.getNodeById(link.target_id);
const theirInput = theirNode.inputs[link.target_slot];
this.#isValidConnection(theirInput, hasConfig);
}
}
isValidWidgetLink(originSlot, targetNode, targetWidget) {
const config2 = getConfig.call(targetNode, targetWidget.name) ?? [
targetWidget.type,
targetWidget.options || {}
];
if (!isConvertibleWidget(targetWidget, config2)) return false;
const output = this.outputs[originSlot];
if (!(output.widget?.[CONFIG] ?? output.widget?.[GET_CONFIG]())) {
return true;
}
return !!mergeIfValid.call(this, output, config2);
}
#isValidConnection(input, forceUpdate) {
const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid.call(
this,
output,
config2,
forceUpdate,
this.recreateWidget
);
}
#removeWidgets() {
if (this.widgets) {
for (const w of this.widgets) {
if (w.onRemove) {
w.onRemove();
}
}
this.controlValues = [];
this.lastType = this.widgets[0]?.type;
for (let i = 1; i < this.widgets.length; i++) {
this.controlValues.push(this.widgets[i].value);
}
setTimeout(() => {
delete this.lastType;
delete this.controlValues;
}, 15);
this.widgets.length = 0;
}
}
onLastDisconnect() {
this.outputs[0].type = "*";
this.outputs[0].name = "connect to widget input";
delete this.outputs[0].widget;
this.#removeWidgets();
}
}
function getWidgetConfig(slot) {
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
}
__name(getWidgetConfig, "getWidgetConfig");
function getConfig(widgetName) {
const { nodeData } = this.constructor;
return nodeData?.input?.required?.[widgetName] ?? nodeData?.input?.optional?.[widgetName];
}
__name(getConfig, "getConfig");
function isConvertibleWidget(widget, config) {
return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput;
}
__name(isConvertibleWidget, "isConvertibleWidget");
function hideWidget(node, widget, suffix = "") {
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
widget.computeSize = () => [0, -4];
widget.type = CONVERTED_TYPE + suffix;
widget.serializeValue = () => {
if (!node.inputs) {
return void 0;
}
let node_input = node.inputs.find((i) => i.widget?.name === widget.name);
if (!node_input || !node_input.link) {
return void 0;
}
return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
};
if (widget.linkedWidgets) {
for (const w of widget.linkedWidgets) {
hideWidget(node, w, ":" + widget.name);
}
}
}
__name(hideWidget, "hideWidget");
function showWidget(widget) {
widget.type = widget.origType;
widget.computeSize = widget.origComputeSize;
widget.serializeValue = widget.origSerializeValue;
delete widget.origType;
delete widget.origComputeSize;
delete widget.origSerializeValue;
if (widget.linkedWidgets) {
for (const w of widget.linkedWidgets) {
showWidget(w);
}
}
}
__name(showWidget, "showWidget");
function convertToInput(node, widget, config) {
hideWidget(node, widget);
const { type } = getWidgetType(config);
const sz = node.size;
const inputIsOptional = !!widget.options?.inputIsOptional;
const input = node.addInput(widget.name, type, {
widget: { name: widget.name, [GET_CONFIG]: () => config },
...inputIsOptional ? { shape: LiteGraph.SlotShape.HollowCircle } : {}
});
for (const widget2 of node.widgets) {
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
}
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
return input;
}
__name(convertToInput, "convertToInput");
function convertToWidget(node, widget) {
showWidget(widget);
const sz = node.size;
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
for (const widget2 of node.widgets) {
widget2.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
}
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
}
__name(convertToWidget, "convertToWidget");
function getWidgetType(config) {
let type = config[0];
if (type instanceof Array) {
type = "COMBO";
}
return { type };
}
__name(getWidgetType, "getWidgetType");
function isValidCombo(combo, obj) {
if (!(obj instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${obj}`);
return false;
}
if (combo.length !== obj.length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
if (combo.find((v, i) => obj[i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
return true;
}
__name(isValidCombo, "isValidCombo");
function isPrimitiveNode(node) {
return node.type === "PrimitiveNode";
}
__name(isPrimitiveNode, "isPrimitiveNode");
function setWidgetConfig(slot, config, target) {
if (!slot.widget) return;
if (config) {
slot.widget[GET_CONFIG] = () => config;
slot.widget[TARGET] = target;
} else {
delete slot.widget;
}
if (slot.link) {
const link = app.graph.links[slot.link];
if (link) {
const originNode = app.graph.getNodeById(link.origin_id);
if (isPrimitiveNode(originNode)) {
if (config) {
originNode.recreateWidget();
} else if (!app.configuringGraph) {
originNode.disconnectOutput(0);
originNode.onLastDisconnect();
}
}
}
}
}
__name(setWidgetConfig, "setWidgetConfig");
function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
}
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return;
} else if (config1[0] !== config2[0]) {
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return;
}
const keys = /* @__PURE__ */ new Set([
...Object.keys(config1[1] ?? {}),
...Object.keys(config2[1] ?? {})
]);
let customConfig;
const getCustomConfig = /* @__PURE__ */ __name(() => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
}, "getCustomConfig");
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline" && k !== "tooltip") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || !v1 && !v2) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
step = v2;
} else if (v2 == null) {
step = v1;
} else {
if (v1 < v2) {
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log(
"connection rejected: steps not divisible",
"current:",
v1,
"new:",
v2
);
return;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.call(this);
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return { customConfig };
}
__name(mergeIfValid, "mergeIfValid");
let useConversionSubmenusSetting;
app.registerExtension({
name: "Comfy.WidgetInputs",
init() {
useConversionSubmenusSetting = app.ui.settings.addSetting({
id: "Comfy.NodeInputConversionSubmenus",
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
type: "boolean",
defaultValue: true
});
},
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
nodeType.prototype.convertWidgetToInput = function(widget) {
const config = getConfig.call(this, widget.name) ?? [
widget.type,
widget.options || {}
];
if (!isConvertibleWidget(widget, config)) return false;
if (widget.type?.startsWith(CONVERTED_TYPE)) return false;
convertToInput(this, widget, config);
return true;
};
nodeType.prototype.getExtraMenuOptions = function(_, options) {
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
if (this.widgets) {
let toInput = [];
let toWidget = [];
for (const w of this.widgets) {
if (w.options?.forceInput) {
continue;
}
if (w.type === CONVERTED_TYPE) {
toWidget.push({
content: `Convert ${w.name} to widget`,
callback: /* @__PURE__ */ __name(() => convertToWidget(this, w), "callback")
});
} else {
const config = getConfig.call(this, w.name) ?? [
w.type,
w.options || {}
];
if (isConvertibleWidget(w, config)) {
toInput.push({
content: `Convert ${w.name} to input`,
callback: /* @__PURE__ */ __name(() => convertToInput(this, w, config), "callback")
});
}
}
}
if (toInput.length) {
if (useConversionSubmenusSetting.value) {
options.push({
content: "Convert Widget to Input",
submenu: {
options: toInput
}
});
} else {
options.push(...toInput, null);
}
}
if (toWidget.length) {
if (useConversionSubmenusSetting.value) {
options.push({
content: "Convert Input to Widget",
submenu: {
options: toWidget
}
});
} else {
options.push(...toWidget, null);
}
}
}
return r;
};
nodeType.prototype.onGraphConfigured = function() {
if (!this.inputs) return;
this.widgets ??= [];
for (const input of this.inputs) {
if (input.widget) {
if (!input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
}
if (input.widget.config) {
if (input.widget.config[0] instanceof Array) {
input.type = "COMBO";
const link = app2.graph.links[input.link];
if (link) {
link.type = input.type;
}
}
delete input.widget.config;
}
const w = this.widgets.find((w2) => w2.name === input.widget.name);
if (w) {
hideWidget(this, w);
} else {
convertToWidget(this, input);
}
}
}
};
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function() {
const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : void 0;
if (!app2.configuringGraph && this.widgets) {
for (const w of this.widgets) {
if (w?.options?.forceInput || w?.options?.defaultInput) {
const config = getConfig.call(this, w.name) ?? [
w.type,
w.options || {}
];
convertToInput(this, w, config);
}
}
}
return r;
};
const origOnConfigure = nodeType.prototype.onConfigure;
nodeType.prototype.onConfigure = function() {
const r = origOnConfigure ? origOnConfigure.apply(this, arguments) : void 0;
if (!app2.configuringGraph && this.inputs) {
for (const input of this.inputs) {
if (input.widget && !input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
const w = this.widgets.find((w2) => w2.name === input.widget.name);
if (w) {
hideWidget(this, w);
}
}
}
}
return r;
};
function isNodeAtPos(pos) {
for (const n of app2.graph.nodes) {
if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
return true;
}
}
return false;
}
__name(isNodeAtPos, "isNodeAtPos");
const origOnInputDblClick = nodeType.prototype.onInputDblClick;
const ignoreDblClick = Symbol();
nodeType.prototype.onInputDblClick = function(slot) {
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : void 0;
const input = this.inputs[slot];
if (!input.widget || !input[ignoreDblClick]) {
if (!(input.type in ComfyWidgets) && !(input.widget[GET_CONFIG]?.()?.[0] instanceof Array)) {
return r;
}
}
const node = LiteGraph.createNode("PrimitiveNode");
app2.graph.add(node);
const pos = [
this.pos[0] - node.size[0] - 30,
this.pos[1]
];
while (isNodeAtPos(pos)) {
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
}
node.pos = pos;
node.connect(0, this, slot);
node.title = input.name;
input[ignoreDblClick] = true;
setTimeout(() => {
delete input[ignoreDblClick];
}, 300);
return r;
};
const onConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function(targetSlot, type, output, originNode, originSlot) {
const v = onConnectInput?.(this, arguments);
if (type !== "COMBO") return v;
if (originNode.outputs[originSlot].widget) return v;
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
if (!targetCombo || !(targetCombo instanceof Array)) return v;
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
return false;
}
return v;
};
},
registerCustomNodes() {
LiteGraph.registerNodeType(
"PrimitiveNode",
Object.assign(PrimitiveNode, {
title: "Primitive"
})
);
PrimitiveNode.category = "utils";
}
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.widgetInputs = window.comfyAPI.widgetInputs || {};
window.comfyAPI.widgetInputs.getWidgetConfig = getWidgetConfig;
window.comfyAPI.widgetInputs.convertToInput = convertToInput;
window.comfyAPI.widgetInputs.setWidgetConfig = setWidgetConfig;
window.comfyAPI.widgetInputs.mergeIfValid = mergeIfValid;
export {
convertToInput,
getWidgetConfig,
mergeIfValid,
setWidgetConfig
};
//# sourceMappingURL=widgetInputs-DdoWwzg5.js.map

1
web/assets/widgetInputs-DdoWwzg5.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,2 +1,2 @@
// Shim for extensions\core\clipspace.ts
// Shim for extensions/core/clipspace.ts
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;

3
web/extensions/core/colorPalette.js vendored Normal file
View File

@@ -0,0 +1,3 @@
// Shim for extensions/core/colorPalette.ts
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;

View File

@@ -1,3 +1,3 @@
// Shim for extensions\core\groupNode.ts
// Shim for extensions/core/groupNode.ts
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;

View File

@@ -1,2 +1,2 @@
// Shim for extensions\core\groupNodeManage.ts
// Shim for extensions/core/groupNodeManage.ts
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;

View File

@@ -1,4 +1,5 @@
// Shim for extensions\core\widgetInputs.ts
// Shim for extensions/core/widgetInputs.ts
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
export const convertToInput = window.comfyAPI.widgetInputs.convertToInput;
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;

92
web/index.html vendored
View File

@@ -1,50 +1,42 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<!-- Browser Test Fonts -->
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<style>
* {
font-family: 'Roboto Mono', 'Noto Color Emoji';
}
</style> -->
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
</head>
<body class="litegraph">
<div id="vue-app"></div>
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-DGAbdBYF.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-BHJGjcJh.css">
</head>
<body class="litegraph grid">
<div id="vue-app"></div>
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
</html>

Some files were not shown because too many files have changed in this diff Show More