Compare commits

...

112 Commits

Author SHA1 Message Date
Dr.Lt.Data
5f9d5a244b Hotfix for the div zero occurrence when memory_used_encode is 0 (#5121)
https://github.com/comfyanonymous/ComfyUI/issues/5069#issuecomment-2382656368
2024-10-09 23:34:34 -04:00
Chenlei Hu
14eba07acd Update web content to release v1.3.11 (#5189)
* Update web content to release v1.3.11

* nit
2024-10-09 22:37:04 -04:00
Jonathan Avila
4b2f0d9413 Increase maximum macOS version to 15.0.1 when forcing upcast attention (#5191) 2024-10-09 22:21:41 -04:00
Yoland Yan
25eac1d780 Change runner label for the new runners (#5197) 2024-10-09 20:08:57 -04:00
comfyanonymous
e38c94228b Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
2024-10-09 19:43:17 -04:00
comfyanonymous
203942c8b2 Fix flux doras with diffusers keys. 2024-10-08 19:03:40 -04:00
Brendan Hoar
3c72c89a52 Update folder_paths.py - try/catch for special file_name values (#5187)
Somehow managed to drop a file called "nul" into a windows checkpoints subdirectory. This caused all sorts of havoc with many nodes that needed the list of checkpoints.
2024-10-08 15:04:32 -04:00
Chenlei Hu
614377abd6 Update web content to release v1.2.64 (#5124) 2024-10-07 17:15:29 -04:00
comfyanonymous
8dfa0cc552 Make SD3 fast previews a little better. 2024-10-07 09:19:59 -04:00
comfyanonymous
e5ecdfdd2d Make fast previews for SDXL a little better by adding a bias. 2024-10-06 19:27:04 -04:00
comfyanonymous
7d29fbf74b Slightly improve the fast previews for flux by adding a bias. 2024-10-06 17:55:46 -04:00
Lex
2c641e64ad IS_CHANGED should be a classmethod (#5159) 2024-10-06 05:47:51 -04:00
comfyanonymous
7d2467e830 Some minor cleanups. 2024-10-05 13:22:39 -04:00
comfyanonymous
6f021d8aa0 Let --verbose have an argument for the log level. 2024-10-04 10:05:34 -04:00
comfyanonymous
d854ed0bcf Allow using SD3 type te output on flux model. 2024-10-03 09:44:54 -04:00
comfyanonymous
abcd006b8c Allow more permutations of clip/t5 in dual clip loader. 2024-10-03 09:26:11 -04:00
comfyanonymous
d985d1d7dc CLIP Loader node now supports clip_l and clip_g only for SD3. 2024-10-02 04:25:17 -04:00
comfyanonymous
d1cdf51e1b Refactor some of the TE detection code. 2024-10-01 07:08:41 -04:00
comfyanonymous
b4626ab93e Add simpletuner lycoris format for SD unet. 2024-09-30 06:03:27 -04:00
comfyanonymous
a9e459c2a4 Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
2024-09-29 11:27:49 -04:00
comfyanonymous
3bb4dec720 Fix issue with loras, lowvram and --fast fp8. 2024-09-28 14:42:32 -04:00
City
8733191563 Flux torch.compile fix (#5082) 2024-09-27 22:07:51 -04:00
comfyanonymous
83b01f960a Add backend option to TorchCompileModel.
If you want to use the cudagraphs backend you need to: --disable-cuda-malloc

If you get other backends working feel free to make a PR to add them.
2024-09-27 02:12:37 -04:00
comfyanonymous
d72e871cfa Add a note that the experimental model downloader api will be removed. 2024-09-26 03:17:52 -04:00
comfyanonymous
037c3159b6 Move some nodes out of _for_testing. 2024-09-25 08:41:22 -04:00
comfyanonymous
bdd4a22a2e Fix flux TE not loading t5 embeddings. 2024-09-24 22:57:22 -04:00
comfyanonymous
fdf37566ef Add batch size to EmptyLatentAudio. 2024-09-24 04:32:55 -04:00
Alex "mcmonkey" Goodwin
08c8968482 Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
2024-09-24 03:50:45 -04:00
chaObserv
479a427a48 Add dpmpp_2m_cfg_pp (#4992) 2024-09-24 02:42:56 -04:00
comfyanonymous
3a0eeee320 Make --listen listen on both ipv4 and ipv6 at the same time by default. 2024-09-23 04:38:19 -04:00
comfyanonymous
447da7ea86 Support listening on multiple addresses. 2024-09-23 04:36:59 -04:00
comfyanonymous
9c41bc8d10 Remove useless line. 2024-09-23 02:32:29 -04:00
Robin Huang
6ad0ddbae4 Run unit tests on Windows/MacOS as well. (#5018)
* Run unit tests on Windows as well.

* Test on mac.

* Continue running on error.

* Compared normalized paths to work cross platform.

* Only test common set of mimetypes across operating systems.
2024-09-22 05:01:39 -04:00
RandomGitUser321
a55142f904 Add ws.close() to the websocket examples (#5020)
* add ws.close() to websocket examples

* add and explain ws.close() in websocket examples
2024-09-22 04:59:10 -04:00
comfyanonymous
5718ef69bb Add total and free ram to /system_stats. 2024-09-22 03:42:11 -04:00
RandomGitUser321
13ecf10a92 Added to the websockets_api_example.py to show how to decode latent previews from the binary stream (#5016)
* Update websockets_api_example.py

* even more simplfied
2024-09-22 02:30:44 -04:00
comfyanonymous
7a415f47a9 Add an optional VAE input to the ControlNetApplyAdvanced node.
Deprecate the other controlnet nodes.
2024-09-22 01:24:52 -04:00
Chenlei Hu
89fa2fca24 Update web content to release v1.2.60 (#5017)
* Update web content to release v1.2.60

* Remove dist.zip
2024-09-21 23:28:54 -04:00
comfyanonymous
364b69e931 Make SD3 empty latent image zeros.
This shouldn't change anything. The reason it was not zeros is because it
did matter in early versions of the code.
2024-09-21 09:13:10 -04:00
comfyanonymous
dc96a1ae19 Load controlnet in fp8 if weights are in fp8. 2024-09-21 04:50:12 -04:00
comfyanonymous
2d810b081e Add load_controlnet_state_dict function. 2024-09-21 01:51:51 -04:00
comfyanonymous
9f7e9f0547 Add an error message when a controlnet needs a VAE but none is given. 2024-09-21 01:33:18 -04:00
comfyanonymous
a355f38ecc Make the SD3 controlnet node the default one. 2024-09-21 01:32:46 -04:00
huchenlei
38c69080c7 Add docstring 2024-09-20 03:16:23 -04:00
comfyanonymous
70a708d726 Fix model merging issue. 2024-09-20 02:31:44 -04:00
yoinked
e7d4782736 add laplace scheduler [2407.03297] (#4990)
* add laplace scheduler [2407.03297]

* should be here instead lol

* better settings
2024-09-19 23:23:09 -04:00
Alex "mcmonkey" Goodwin
3326bdfd4e add internal /folder_paths route (#4980)
returns a json maps of folder paths
2024-09-19 09:52:55 -04:00
Alex "mcmonkey" Goodwin
68bb885d22 add 'is_default' to model paths config (#4979)
* add 'is_default' to model paths config

including impl and doc in example file

* update weirdly overspecific test expectations

* oh there's two

* sigh
2024-09-19 08:59:55 -04:00
comfyanonymous
ad66f7c7d8 Add model_options to load_controlnet function. 2024-09-19 08:23:35 -04:00
Simon Lui
de8e8e3b0d Fix xpu Pytorch nightly build from calling optimize which doesn't exist. (#4978) 2024-09-19 05:11:42 -04:00
Alex "mcmonkey" Goodwin
a1e71cfad1 very simple strong-cache on model list (#4969)
* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
2024-09-19 04:40:14 -04:00
comfyanonymous
0bfc7cc998 Create the temp directory on ComfyUI startup instead. 2024-09-18 09:55:57 -04:00
Tom
7183fd1665 Add route to list model types (#4846)
* Add list models route

* Better readable model types list
2024-09-17 04:22:05 -04:00
Alex "mcmonkey" Goodwin
254838f23c add simple error check to model loading (#4950) 2024-09-17 03:57:17 -04:00
pharmapsychotic
0b7dfa986d Improve tiling calculations to reduce number of tiles that need to be processed. (#4944) 2024-09-17 03:51:10 -04:00
comfyanonymous
d514bb38ee Add some option to model_options for the text encoder.
load_device, offload_device and the initial_device can now be set.
2024-09-17 03:49:54 -04:00
comfyanonymous
0849c80e2a get_key_patches now works without unloading the model. 2024-09-17 01:57:59 -04:00
comfyanonymous
56e8f5e4fd VAEDecodeAudio now does some normalization on the audio. 2024-09-16 00:30:36 -04:00
comfyanonymous
e813abbb2c Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
2024-09-15 07:59:38 -04:00
JettHu
5e68a4ce67 Reduce repeated calls of INPUT_TYPES in cache (#4922) 2024-09-15 01:03:09 -04:00
comfyanonymous
ca08597670 Make the inpaint controlnet node work with non inpaint ones. 2024-09-14 09:17:13 -04:00
comfyanonymous
f48e390032 Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
2024-09-14 09:05:16 -04:00
Chenlei Hu
369a6dd2c4 Remove empty spaces in user_manager.py (#4917) 2024-09-13 23:30:44 -04:00
comfyanonymous
b3ce8fb9fd Revert "Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871)"
This reverts commit f6b7194f64.
2024-09-13 23:24:47 -04:00
comfyanonymous
cf80d28689 Support loading controlnets with different input. 2024-09-13 09:54:37 -04:00
Acly
6fb44c4b7c Make adding links/nodes to ExecutionList non-recursive (#4886)
Graphs with 300+ chained nodes run into maximum recursion depth error (limit is 1000 in CPython)
2024-09-13 08:25:11 -04:00
Chenlei Hu
d2247c1e61 Normalize path returned by /userdata to always use / as separator (#4906) 2024-09-13 03:45:31 -04:00
Chenlei Hu
cb12ad7049 Add full_info flag in /userdata endpoint to list out file size and last modified timestamp (#4905)
* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
2024-09-13 02:40:59 -04:00
JettHu
f6b7194f64 Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871) 2024-09-12 23:02:52 -04:00
comfyanonymous
7c6eb4fb29 Set some nodes as DEPRECATED. 2024-09-12 20:27:07 -04:00
Robin Huang
b962db9952 Add cli arg to override user directory (#4856)
* Override user directory.

* Use overridden user directory.

* Remove prints.

* Remove references to global user_files.

* Remove unused replace_folder function.

* Remove newline.

* Remove global during get_user_directory.

* Add validation.
2024-09-12 08:10:27 -04:00
comfyanonymous
d0b7ab88ba Add a simple experimental TorchCompileModel node.
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
2024-09-12 05:24:25 -04:00
Yoland Yan
405b529545 Minor: update tests-unit README.md (#4896) 2024-09-12 04:53:08 -04:00
comfyanonymous
9d720187f1 types -> comfy_types to fix import issue. 2024-09-12 03:57:46 -04:00
Robin Huang
d247bc5a9c Expand variables in base_path for extra_config_paths.yaml. (#4893)
* Expand variables in base_path for extra_config_paths.yaml.

* Fix comments.
2024-09-12 01:52:06 -04:00
comfyanonymous
9f4daca9d9 Doesn't really make sense for cfg_pp sampler to call regular one. 2024-09-11 02:51:36 -04:00
yoinked
b5d0f2a908 Add CFG++ to DPM++ 2S Ancestral (#3871)
* Update sampling.py

* Update samplers.py

* my bad

* "fix" the sampler

* Update samplers.py

* i named it wrong

* minor sampling improvements

mainly using a dynamic rho value (hey this sounds a lot like smea!!!)

* revert rho change

rho? r? its just 1/2
2024-09-11 02:49:44 -04:00
bymyself
e760bf5c40 Add content-type filter method to folder_paths (#4054)
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
2024-09-11 02:00:07 -04:00
comfyanonymous
36c83cdbba Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people
who use reverse proxies.
2024-09-11 01:06:37 -04:00
Yoland Yan
81778a7feb [🗻 Mount Fuji Commit] Add unit tests for folder path utilities (#4869)
All past 30 min of comtts are done on the top of Mt Fuji
By Comfy, Robin, and Yoland
All other comfy org members died on the way

Introduced unit tests to verify the correctness of various folder path
utility functions such as `get_directory_by_type`, `annotated_filepath`,
and `recursive_search` among others. These tests cover scenarios
including directory retrieval, filepath annotation, recursive file
searches, and filtering files by extensions, enhancing the robustness
and reliability of the codebase.
2024-09-10 00:44:49 -04:00
comfyanonymous
bc94662b31 Cleanup. 2024-09-10 00:43:37 -04:00
Robin Huang
9fa8faa44a Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path.

* Add test.

* Add unit test for expanding base path.

* Simplify unit test.

* Remove comment.

* Remove comment.

* Checkpoints.

* Refactor.
2024-09-10 00:33:44 -04:00
comfyanonymous
9a7444e39f Add diffusion_models to the extra_model_paths.yaml.example 2024-09-10 00:21:33 -04:00
comfyanonymous
54fca4a218 If host does not contain a port only compare the hostnames. 2024-09-09 16:28:23 -04:00
Chenlei Hu
cd4955367e Add back CI action for tests-ui (#4859) 2024-09-09 04:32:55 -04:00
david02871
8354203d95 Add .venv to gitignore (#4756) 2024-09-09 04:31:18 -04:00
comfyanonymous
e0b41243b4 Fix issue where sometimes origin doesn't contain the port. 2024-09-09 03:18:17 -04:00
Alex "mcmonkey" Goodwin
619263d4a6 allow current timestamp in save image prefix (#4030) 2024-09-09 02:55:51 -04:00
comfyanonymous
e3b0402bb7 Ignore origin domain when it's empty. 2024-09-09 01:04:56 -04:00
Darion
967867d48c fix: url decode filename from API (#4801) 2024-09-08 21:02:32 -04:00
comfyanonymous
cbaac71bf5 Fix issue with last commit. 2024-09-08 19:35:23 -04:00
comfyanonymous
3ab3516e46 By default only accept requests where origin header matches the host.
Browsers are dumb and let any website do requests to localhost this should
prevent this without breaking things. CORS prevents the javascript from
reading the response but they can still write it.

At the moment this is only enabled when the --enable-cors-header argument
is not used.
2024-09-08 18:17:29 -04:00
comfyanonymous
9c5fca75f4 Fix lora issue. 2024-09-08 10:10:47 -04:00
guill
a5da4d0b3e Fix error with ExecutionBlocker and OUTPUT_IS_LIST (#4836)
This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
2024-09-08 09:48:47 -04:00
comfyanonymous
32a60a7bac Support onetrainer text encoder Flux lora. 2024-09-08 09:31:41 -04:00
Jim Winkens
bb52934ba4 Fix import issue (#4815) 2024-09-07 05:28:32 -04:00
comfyanonymous
8aabd7c8c0 SaveLora node can now save "full diff" lora format.
This isn't actually a lora format and is saving the full diff of the
weights in a format that can be used in the lora loader nodes.
2024-09-07 03:21:02 -04:00
comfyanonymous
a09b29ca11 Add an option to the SaveLora node to store the bias diff. 2024-09-07 03:03:30 -04:00
comfyanonymous
9bfee68773 LoraSave node now supports generating text encoder loras.
text_encoder_diff should be connected to a CLIPMergeSubtract node.

model_diff and text_encoder_diff are optional inputs so you can create
model only loras, text encoder only loras or a lora that contains both.
2024-09-07 02:30:12 -04:00
comfyanonymous
ea77750759 Support a generic Comfy format for text encoder loras.
This is a format with keys like:
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats you can
convert your text encoder loras to this format instead.

If you want to see an example save a text encoder lora with the SaveLora
node with the commit right after this one.
2024-09-07 02:20:39 -04:00
comfyanonymous
c27ebeb1c2 Fix onnx export not working on flux. 2024-09-06 03:21:52 -04:00
guill
0c7c98a965 Nodes using UNIQUE_ID as input are NOT_IDEMPOTENT (#4793)
As suggested by @ltdrdata, we can automatically consider nodes that take
the UNIQUE_ID hidden input to be NOT_IDEMPOTENT.
2024-09-05 19:33:02 -04:00
comfyanonymous
dc2eb75b85 Update stable release workflow to latest pytorch with cuda 12.4. 2024-09-05 19:21:52 -04:00
Chenlei Hu
fa34efe3bd Update frontend to v1.2.47 (#4798)
* Update web content to release v1.2.47

* Update shortcut list
2024-09-05 18:56:01 -04:00
comfyanonymous
5cbaa9e07c Mistoline flux controlnet support. 2024-09-05 00:05:17 -04:00
comfyanonymous
c7427375ee Prioritize freeing partially offloaded models first. 2024-09-04 19:47:32 -04:00
comfyanonymous
22d1241a50 Add an experimental LoraSave node to extract model loras.
The model_diff input should be connected to the output of a
ModelMergeSubtract node.
2024-09-04 16:38:38 -04:00
Jedrzej Kosinski
f04229b84d Add emb_patch support to UNetModel forward (#4779) 2024-09-04 14:35:15 -04:00
Silver
f067ad15d1 Make live preview size a configurable launch argument (#4649)
* Make live preview size a configurable launch argument

* Remove import from testing phase

* Update cli_args.py
2024-09-03 19:16:38 -04:00
comfyanonymous
483004dd1d Support newer glora format. 2024-09-03 17:02:19 -04:00
comfyanonymous
00a5d08103 Lower fp8 lora memory usage. 2024-09-03 01:25:05 -04:00
comfyanonymous
d043997d30 Flux onetrainer lora. 2024-09-02 08:22:15 -04:00
139 changed files with 111976 additions and 87396 deletions

View File

@@ -23,7 +23,7 @@ jobs:
runner_label: [self-hosted, Linux] runner_label: [self-hosted, Linux]
flags: "" flags: ""
- os: windows - os: windows
runner_label: [self-hosted, win] runner_label: [self-hosted, Windows]
flags: "" flags: ""
runs-on: ${{ matrix.runner_label }} runs-on: ${{ matrix.runner_label }}
steps: steps:

View File

@@ -12,7 +12,7 @@ on:
description: 'CUDA version' description: 'CUDA version'
required: true required: true
type: string type: string
default: "121" default: "124"
python_minor: python_minor:
description: 'Python minor version' description: 'Python minor version'
required: true required: true

View File

@@ -32,7 +32,7 @@ jobs:
runner_label: [self-hosted, Linux] runner_label: [self-hosted, Linux]
flags: "" flags: ""
- os: windows - os: windows
runner_label: [self-hosted, win] runner_label: [self-hosted, Windows]
flags: "" flags: ""
runs-on: ${{ matrix.runner_label }} runs-on: ${{ matrix.runner_label }}
steps: steps:
@@ -55,7 +55,7 @@ jobs:
torch_version: ["nightly"] torch_version: ["nightly"]
include: include:
- os: windows - os: windows
runner_label: [self-hosted, win] runner_label: [self-hosted, Windows]
flags: "" flags: ""
runs-on: ${{ matrix.runner_label }} runs-on: ${{ matrix.runner_label }}
steps: steps:

30
.github/workflows/test-unit.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Unit Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit

1
.gitignore vendored
View File

@@ -12,6 +12,7 @@ extra_model_paths.yaml
.vscode/ .vscode/
.idea/ .idea/
venv/ venv/
.venv/
/web/extensions/* /web/extensions/*
!/web/extensions/logging.js.example !/web/extensions/logging.js.example
!/web/extensions/core/ !/web/extensions/core/

View File

@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Alt + `+` | Canvas Zoom in | | Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out | | Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out | | Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue | | Q | Toggle visibility of the queue |
| H | Toggle visibility of history | | H | Toggle visibility of history |
| R | Refresh graph | | R | Refresh graph |

View File

@@ -1,6 +1,6 @@
from aiohttp import web from aiohttp import web
from typing import Optional from typing import Optional
from folder_paths import models_dir, user_directory, output_directory from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService from api_server.services.file_service import FileService
import app.logger import app.logger
@@ -36,6 +36,13 @@ class InternalRoutes:
async def get_logs(request): async def get_logs(request):
return web.json_response(app.logger.get_logs()) return web.json_response(app.logger.get_logs())
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self): def get_app(self):
if self._app is None: if self._app is None:
self._app = web.Application() self._app = web.Application()

View File

@@ -10,14 +10,14 @@ def get_logs():
return "\n".join([formatter.format(x) for x in logs]) return "\n".join([formatter.format(x) for x in logs])
def setup_logger(verbose: bool = False, capacity: int = 300): def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs global logs
if logs: if logs:
return return
# Setup default global logger # Setup default global logger
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.DEBUG if verbose else logging.INFO) logger.setLevel(log_level)
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s")) stream_handler.setFormatter(logging.Formatter("%(message)s"))

View File

@@ -5,17 +5,17 @@ import uuid
import glob import glob
import shutil import shutil
from aiohttp import web from aiohttp import web
from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
from folder_paths import user_directory import folder_paths
from .app_settings import AppSettings from .app_settings import AppSettings
default_user = "default" default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager(): class UserManager():
def __init__(self): def __init__(self):
global user_directory user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self) self.settings = AppSettings(self)
if not os.path.exists(user_directory): if not os.path.exists(user_directory):
@@ -25,14 +25,17 @@ class UserManager():
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user: if args.multi_user:
if os.path.isfile(users_file): if os.path.isfile(self.get_users_file()):
with open(users_file) as f: with open(self.get_users_file()) as f:
self.users = json.load(f) self.users = json.load(f)
else: else:
self.users = {} self.users = {}
else: else:
self.users = {"default": "default"} self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request): def get_request_user_id(self, request):
user = "default" user = "default"
if args.multi_user and "comfy-user" in request.headers: if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +47,7 @@ class UserManager():
return user return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory user_directory = folder_paths.get_user_directory()
if type == "userdata": if type == "userdata":
root_dir = user_directory root_dir = user_directory
@@ -59,6 +62,10 @@ class UserManager():
return None return None
if file is not None: if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user} # prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file)) path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root: if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +87,7 @@ class UserManager():
self.users[user_id] = name self.users[user_id] = name
global users_file with open(self.get_users_file(), "w") as f:
with open(users_file, "w") as f:
json.dump(self.users, f) json.dump(self.users, f)
return user_id return user_id
@@ -112,25 +118,69 @@ class UserManager():
@routes.get("/userdata") @routes.get("/userdata")
async def listuserdata(request): async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '') directory = request.rel_url.query.get('dir', '')
if not directory: if not directory:
return web.Response(status=400) return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory) path = self.get_request_user_filepath(request, directory)
if not path: if not path:
return web.Response(status=403) return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path): if not os.path.exists(path):
return web.Response(status=404) return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true" recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join( full_info = request.rel_url.query.get('full_info', '').lower() == "true"
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] # Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true" split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path: if split_path and not full_info:
results = [[x] + x.split(os.sep) for x in results] results = [[x] + x.split('/') for x in results]
return web.json_response(results) return web.json_response(results)
@@ -138,14 +188,14 @@ class UserManager():
file = request.match_info.get(param, None) file = request.match_info.get(param, None)
if not file: if not file:
return web.Response(status=400) return web.Response(status=400)
path = self.get_request_user_filepath(request, file) path = self.get_request_user_filepath(request, file)
if not path: if not path:
return web.Response(status=403) return web.Response(status=403)
if check_exists and not os.path.exists(path): if check_exists and not os.path.exists(path):
return web.Response(status=404) return web.Response(status=404)
return path return path
@routes.get("/userdata/{file}") @routes.get("/userdata/{file}")
@@ -153,7 +203,7 @@ class UserManager():
path = get_user_data_path(request, check_exists=True) path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str): if not isinstance(path, str):
return path return path
return web.FileResponse(path) return web.FileResponse(path)
@routes.post("/userdata/{file}") @routes.post("/userdata/{file}")
@@ -161,7 +211,7 @@ class UserManager():
path = get_user_data_path(request) path = get_user_data_path(request)
if not isinstance(path, str): if not isinstance(path, str):
return path return path
overwrite = request.query["overwrite"] != "false" overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path): if not overwrite and os.path.exists(path):
return web.Response(status=409) return web.Response(status=409)
@@ -170,7 +220,7 @@ class UserManager():
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(body) f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None)) resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp) return web.json_response(resp)
@@ -181,7 +231,7 @@ class UserManager():
return path return path
os.remove(path) os.remove(path)
return web.Response(status=204) return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}") @routes.post("/userdata/{file}/move/{dest}")
@@ -189,17 +239,17 @@ class UserManager():
source = get_user_data_path(request, check_exists=True) source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str): if not isinstance(source, str):
return source return source
dest = get_user_data_path(request, check_exists=False, param="dest") dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str): if not isinstance(source, str):
return dest return dest
overwrite = request.query["overwrite"] != "false" overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest): if not overwrite and os.path.exists(dest):
return web.Response(status=409) return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'") print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest) shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None)) resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp) return web.json_response(resp)

View File

@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__( def __init__(
self, self,
num_blocks = None, num_blocks = None,
control_latent_channels = None,
dtype = None, dtype = None,
device = None, device = None,
operations = None, operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)): for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)) self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed( self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None, None,
self.patch_size, self.patch_size,
self.in_channels, control_latent_channels,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
strict_img_size=False, strict_img_size=False,

View File

@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function") parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function") parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
# The default built-in provider hosted under web/ # The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -169,6 +171,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
) )
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:

View File

@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
if k not in u: if k not in u:
t = sd.pop(k) sd.pop(k)
del t
return clip return clip
def load(ckpt_path): def load(ckpt_path):

View File

@@ -79,13 +79,21 @@ class ControlBase:
self.previous_controlnet = None self.previous_controlnet = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
self.strength = strength self.strength = strength
self.timestep_percent_range = timestep_percent_range self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None: if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self return self
def pre_run(self, model, percent_to_timestep_function): def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +108,9 @@ class ControlBase:
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
self.previous_controlnet.cleanup() self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint self.cond_hint = None
self.cond_hint = None self.extra_concat = None
self.timestep_range = None self.timestep_range = None
def get_models(self): def get_models(self):
@@ -123,6 +131,8 @@ class ControlBase:
c.vae = self.vae c.vae = self.vae
c.extra_conds = self.extra_conds.copy() c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@@ -175,7 +185,7 @@ class ControlBase:
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
@@ -189,6 +199,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format self.latent_format = latent_format
self.extra_conds += extra_conds self.extra_conds += extra_conds
self.strength_type = strength_type self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@@ -213,6 +224,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio compression_ratio = self.compression_ratio
if self.vae is not None: if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None: if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True) loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -220,6 +234,13 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models) comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None: if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint) self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -319,7 +340,7 @@ class ControlLoraOps:
class ControlLora(ControlNet): class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None): def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device) ControlBase.__init__(self, device)
self.control_weights = control_weights self.control_weights = control_weights
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
@@ -376,19 +397,25 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd): def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True) model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast operations = model_options.get("custom_operations", None)
else: if operations is None:
operations = comfy.ops.disable_weight_init operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device() offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@@ -403,24 +430,29 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model return control_model
def load_controlnet_mmdit(sd): def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd: for k in sd:
new_sd[k] = sd[k] new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd) control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3() latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
def load_controlnet_hunyuandit(controlnet_data): def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype) control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data) control_model = controlnet_load_state_dict(control_model, controlnet_data)
@@ -430,17 +462,17 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control return control
def load_controlnet_flux_xlabs(sd): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd) control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance'] extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control return control
def load_controlnet_flux_instantx(sd): def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd: for k in sd:
new_sd[k] = sd[k] new_sd[k] = sd[k]
@@ -449,21 +481,30 @@ def load_controlnet_flux_instantx(sd):
if union_cnet in new_sd: if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0] num_union_modes = new_sd[union_cnet].shape[0]
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd) control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux() latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance'] extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control return control
def load_controlnet(ckpt_path, model=None): def convert_mistoline(sd):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data) return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data: if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data) return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None controlnet_config = None
supported_inference_dtypes = None supported_inference_dtypes = None
@@ -518,13 +559,15 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0: if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys)) logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data) return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data: elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data) return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data: elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data) return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight' pth_key = 'control_model.zero_convs.0.0.weight'
pth = False pth = False
@@ -536,25 +579,36 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data: elif key in controlnet_data:
prefix = "" prefix = ""
else: else:
net = load_t2i_adapter(controlnet_data) net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None: if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) logging.error("error could not detect control model type.")
return net return net
if controlnet_config is None: if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None: operations = model_options.get("custom_operations", None)
controlnet_config["operations"] = comfy.ops.manual_cast if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device() controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
@@ -590,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0: if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False global_average_pooling = model_options.get("global_average_pooling", False)
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase): class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device) super().__init__(device)
@@ -653,7 +714,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8 compression_ratio = 8
upscale_algorithm = 'nearest-exact' upscale_algorithm = 'nearest-exact'

View File

@@ -1,5 +1,4 @@
import torch import torch
import math
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where( mantissa_scaled = torch.where(
@@ -41,9 +40,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
) )
del abs_x
return sign.to(dtype=dtype) return sign
@@ -57,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device) generator = torch.Generator(device=value.device)
generator.manual_seed(seed) generator.manual_seed(seed)
return manual_stochastic_round_to_float8(value, dtype, generator=generator) output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype) return value.to(dtype=dtype)

View File

@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas) return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised): def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative.""" """Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim) return (x - denoised) / utils.append_dims(sigma, x.ndim)
@@ -1101,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0: if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
latent_channels = 4 latent_channels = 4
latent_rgb_factors = None latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None taesd_decoder_name = None
def process_in(self, latent): def process_in(self, latent):
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
def __init__(self): def __init__(self):
self.latent_rgb_factors = [ self.latent_rgb_factors = [
# R G B # R G B
[ 0.3920, 0.4054, 0.4549], [ 0.3651, 0.4232, 0.4341],
[-0.2634, -0.0196, 0.0653], [-0.2533, -0.0042, 0.1068],
[ 0.0568, 0.1687, -0.0755], [ 0.1076, 0.1111, -0.0362],
[-0.3112, -0.2359, -0.2076] [-0.3165, -0.2492, -0.2188]
] ]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder" self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat): class SDXL_Playground_2_5(LatentFormat):
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305 self.scale_factor = 1.5305
self.shift_factor = 0.0609 self.shift_factor = 0.0609
self.latent_rgb_factors = [ self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052], [-0.0922, -0.0175, 0.0749],
[ 0.0028, 0.0312, 0.0650], [ 0.0311, 0.0633, 0.0954],
[ 0.1848, 0.0762, 0.0360], [ 0.1994, 0.0927, 0.0458],
[ 0.0944, 0.0360, 0.0889], [ 0.0856, 0.0339, 0.0902],
[ 0.0897, 0.0506, -0.0364], [ 0.0587, 0.0272, -0.0496],
[-0.0020, 0.1203, 0.0284], [-0.0006, 0.1104, 0.0309],
[ 0.0855, 0.0118, 0.0283], [ 0.0978, 0.0306, 0.0427],
[-0.0539, 0.0658, 0.1047], [-0.0042, 0.1038, 0.1358],
[-0.0057, 0.0116, 0.0700], [-0.0194, 0.0020, 0.0669],
[-0.0412, 0.0281, -0.0039], [-0.0488, 0.0130, -0.0268],
[ 0.1106, 0.1171, 0.1220], [ 0.0922, 0.0988, 0.0951],
[-0.0248, 0.0682, -0.0481], [-0.0278, 0.0524, -0.0542],
[ 0.0815, 0.0846, 0.1207], [ 0.0332, 0.0456, 0.0895],
[-0.0120, -0.0055, -0.0867], [-0.0069, -0.0030, -0.0810],
[-0.0749, -0.0634, -0.0456], [-0.0596, -0.0465, -0.0293],
[-0.1418, -0.1457, -0.1259] [-0.1448, -0.1463, -0.1189]
] ]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder" self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent): def process_in(self, latent):
@@ -146,23 +150,24 @@ class Flux(SD3):
self.scale_factor = 0.3611 self.scale_factor = 0.3611
self.shift_factor = 0.1159 self.shift_factor = 0.1159
self.latent_rgb_factors =[ self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609], [-0.0346, 0.0244, 0.0681],
[ 0.0043, 0.0298, 0.0850], [ 0.0034, 0.0210, 0.0687],
[ 0.0328, -0.0749, -0.0503], [ 0.0275, -0.0668, -0.0433],
[-0.0245, 0.0085, 0.0549], [-0.0174, 0.0160, 0.0617],
[ 0.0966, 0.0894, 0.0530], [ 0.0859, 0.0721, 0.0329],
[ 0.0035, 0.0399, 0.0123], [ 0.0004, 0.0383, 0.0115],
[ 0.0583, 0.1184, 0.1262], [ 0.0405, 0.0861, 0.0915],
[-0.0191, -0.0206, -0.0306], [-0.0236, -0.0185, -0.0259],
[-0.0324, 0.0055, 0.1001], [-0.0245, 0.0250, 0.1180],
[ 0.0955, 0.0659, -0.0545], [ 0.1008, 0.0755, -0.0421],
[-0.0504, 0.0231, -0.0013], [-0.0515, 0.0201, 0.0011],
[ 0.0500, -0.0008, -0.0088], [ 0.0428, -0.0012, -0.0036],
[ 0.0982, 0.0941, 0.0976], [ 0.0817, 0.0765, 0.0749],
[-0.1233, -0.0280, -0.0897], [-0.1264, -0.0522, -0.1103],
[-0.0005, -0.0530, -0.0020], [-0.0280, -0.0881, -0.0499],
[-0.1273, -0.0932, -0.0680] [-0.1262, -0.0982, -0.0778]
] ]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder" self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent): def process_in(self, latent):

View File

@@ -14,7 +14,7 @@ except:
rms_norm_torch = None rms_norm_torch = None
def rms_norm(x, weight, eps=1e-6): def rms_norm(x, weight, eps=1e-6):
if rms_norm_torch is not None: if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else: else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)

View File

@@ -1,4 +1,5 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch import torch
import math import math
@@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
from .model import Flux from .model import Flux
import comfy.ldm.common_dit import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux): class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs): def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19 self.main_model_double = 19
self.main_model_single = 38 self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks # add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([]) self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth): for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) self.controlnet_blocks.append(control_block())
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([]) self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks): for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None self.controlnet_mode_embedder = None
@@ -36,25 +80,33 @@ class ControlNetFlux(Flux):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.latent_input = latent_input self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input: if not self.latent_input:
self.input_hint_block = nn.Sequential( if self.mistoline:
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
nn.SiLU(), else:
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), self.input_hint_block = nn.Sequential(
nn.SiLU(), operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), nn.SiLU(),
nn.SiLU(), operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), nn.SiLU(),
nn.SiLU(), operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), nn.SiLU(),
nn.SiLU(), operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), nn.SiLU(),
nn.SiLU(), operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), nn.SiLU(),
nn.SiLU(), operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) nn.SiLU(),
) operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig( def forward_orig(
self, self,
@@ -73,9 +125,6 @@ class ControlNetFlux(Flux):
# running on sequences img # running on sequences img
img = self.img_in(img) img = self.img_in(img)
if not self.latent_input:
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond) controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond img = img + controlnet_cond
@@ -131,9 +180,14 @@ class ControlNetFlux(Flux):
patch_size = 2 patch_size = 2
if self.latent_input: if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else: else:
hint = hint * 2.0 - 1.0 hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

View File

@@ -108,7 +108,7 @@ class Flux(nn.Module):
raise ValueError("Didn't get guidance strength for guidance distilled model.") raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y) vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt) txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -151,8 +151,8 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size) h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None: if self.num_classes is not None:
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)

View File

@@ -201,9 +201,13 @@ def load_lora(lora, to_load):
def model_lora_keys_clip(model, key_map={}): def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys() sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP: for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -227,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
clip_g_present = True
if clip_l_present: if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k key_map[lora_key] = k
@@ -242,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk: for k in sdk:
if k.endswith(".weight"): if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")] l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) t5_index = 1
key_map[lora_key] = k if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
@@ -281,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k]) unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_") key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."] diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix: for p in diffusers_lora_prefix:
@@ -324,14 +338,15 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k] to = diffusers_keys[k]
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype): def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype) weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = ( weight_norm = (
weight_calc.transpose(0, 1) weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1) .reshape(weight_calc.shape[1], -1)
@@ -400,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight *= strength_model weight *= strength_model
if isinstance(v, list): if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), ) v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1: if len(v) == 1:
patch_type = "diff" patch_type = "diff"
@@ -438,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try: try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None: if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype)) weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else: else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e: except Exception as e:
@@ -484,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try: try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape) lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None: if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype)) weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else: else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e: except Exception as e:
@@ -521,28 +536,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try: try:
lora_diff = (m1 * m2).reshape(weight.shape) lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None: if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype)) weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else: else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e: except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e)) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora": elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5] dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try: try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None: if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype)) weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else: else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e: except Exception as e:

View File

@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try: try:
logging.info("pytorch version: {}".format(torch.version.__version__)) logging.info("pytorch version: {}".format(torch_version))
except: except:
pass pass
@@ -326,7 +326,7 @@ class LoadedModel:
self.model_unload() self.model_unload()
raise e raise e
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None: if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
with torch.no_grad(): with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
@@ -426,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
if shift_model.device == device: if shift_model.device == device:
if shift_model not in keep_loaded: if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
for x in sorted(can_unload): for x in sorted(can_unload):
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory()) return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
if args.fp16_unet: if args.fp16_unet:
@@ -897,7 +899,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention upcast = args.force_upcast_attention
try: try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split(".")) macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
upcast = True upcast = True
except: except:
pass pass
@@ -1063,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
def supports_fp8_compute(device=None): def supports_fp8_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device) props = torch.cuda.get_device_properties(device)
if props.major >= 9: if props.major >= 9:
return True return True
@@ -1070,6 +1075,14 @@ def supports_fp8_compute(device=None):
return False return False
if props.minor < 9: if props.minor < 9:
return False return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
return False
return True return True
def soft_empty_cache(force=False): def soft_empty_cache(force=False):

View File

@@ -28,7 +28,7 @@ import comfy.utils
import comfy.float import comfy.float
import comfy.model_management import comfy.model_management
import comfy.lora import comfy.lora
from comfy.types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
def string_to_seed(data): def string_to_seed(data):
crc = 0xFFFFFFFF crc = 0xFFFFFFFF
@@ -88,8 +88,12 @@ class LowVramPatch:
self.key = key self.key = key
self.patches = patches self.patches = patches
def __call__(self, weight): def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) intermediate_dtype = weight.dtype
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size self.size = size
@@ -283,17 +287,21 @@ class ModelPatcher:
return list(p) return list(p)
def get_key_patches(self, filter_prefix=None): def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
p = {} p = {}
for k in model_sd: for k in model_sd:
if filter_prefix is not None: if filter_prefix is not None:
if not k.startswith(filter_prefix): if not k.startswith(filter_prefix):
continue continue
if k in self.patches: bk = self.backup.get(k, None)
p[k] = [model_sd[k]] + self.patches[k] if bk is not None:
weight = bk.weight
else: else:
p[k] = (model_sd[k],) weight = model_sd[k]
if k in self.patches:
p[k] = [weight] + self.patches[k]
else:
p[k] = (weight,)
return p return p
def model_state_dict(self, filter_prefix=None): def model_state_dict(self, filter_prefix=None):

View File

@@ -260,7 +260,6 @@ def fp8_linear(self, input):
if len(input.shape) == 3: if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype) inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t() w = w.t()
@@ -300,10 +299,14 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
if comfy.model_management.supports_fp8_compute(load_device):
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype: if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init return disable_weight_init
if args.fast: if args.fast and not disable_fast_fp8:
if comfy.model_management.supports_fp8_compute(load_device): if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops return fp8_ops
return manual_cast return manual_cast

View File

@@ -6,7 +6,7 @@ from comfy import model_management
import math import math
import logging import logging
import comfy.sampler_helpers import comfy.sampler_helpers
import scipy import scipy.stats
import numpy import numpy
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
@@ -570,8 +570,8 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"] "ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):

View File

@@ -29,7 +29,6 @@ import comfy.text_encoders.long_clipl
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd import comfy.taesd.taesd
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -70,14 +69,14 @@ class CLIP:
clip = target.clip clip = target.clip
tokenizer = target.tokenizer tokenizer = target.tokenizer
load_device = model_management.text_encoder_device() load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_management.text_encoder_offload_device() offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_options.get("dtype", None) dtype = model_options.get("dtype", None)
if dtype is None: if dtype is None:
dtype = model_management.text_encoder_dtype(load_device) dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)) params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
params['model_options'] = model_options params['model_options'] = model_options
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
@@ -348,7 +347,7 @@ class VAE:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device) samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
@@ -406,8 +405,35 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
class TEModel(Enum):
CLIP_L = 1
CLIP_H = 2
CLIP_G = 3
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return TEModel.CLIP_G
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
return None
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts clip_data = state_dicts
class EmptyClass: class EmptyClass:
pass pass
@@ -421,39 +447,42 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target = EmptyClass() clip_target = EmptyClass()
clip_target.params = {} clip_target.params = {}
if len(clip_data) == 1: if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: te_model = detect_te_model(clip_data[0])
if te_model == TEModel.CLIP_G:
if clip_type == CLIPType.STABLE_CASCADE: if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else: else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: elif te_model == TEModel.CLIP_H:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: elif te_model == TEModel.T5_XXL:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype dtype_t5 = weight.dtype
if weight.shape[-1] == 4096: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif te_model == TEModel.T5_XL:
elif weight.shape[-1] == 2048: clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer elif te_model == TEModel.T5_BASE:
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else: else:
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None) if clip_type == CLIPType.SD3:
if w is not None and w.shape[0] == 248: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
else: else:
clip_target.clip = sd1_clip.SD1ClipModel clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2: elif len(clip_data) == 2:
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT: elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel clip_target.clip = comfy.text_encoders.hydit.HyditModel
@@ -475,10 +504,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0 parameters = 0
tokenizer_data = {}
for c in clip_data: for c in clip_data:
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options) clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data: for c in clip_data:
m, u = clip.load_sd(c) m, u = clip.load_sd(c)
if len(m) > 0: if len(m) > 0:
@@ -562,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix) model.load_model_weights(sd, diffusion_model_prefix)
@@ -647,7 +677,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")

View File

@@ -542,6 +542,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -570,6 +571,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set() self.dtypes = set()

View File

@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer: class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -40,7 +41,8 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype]) self.dtypes = set([dtype])
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd): def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:

View File

@@ -49,6 +49,7 @@ class BASE:
manual_cast_dtype = None manual_cast_dtype = None
custom_operations = None custom_operations = None
optimizations = {"fp8": False}
@classmethod @classmethod
def matches(s, unet_config, state_dict=None): def matches(s, unet_config, state_dict=None):

View File

@@ -13,12 +13,13 @@ class T5XXLModel(sd1_clip.SDClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer: class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5]) self.dtypes = set([dtype, dtype_t5])

View File

@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel): class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options) super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer): class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
class LongClipModel(sd1_clip.SD1ClipModel): class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@@ -20,7 +20,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer: class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -42,7 +43,8 @@ class SD3ClipModel(torch.nn.Module):
super().__init__() super().__init__()
self.dtypes = set() self.dtypes = set()
if clip_l: if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype) self.dtypes.add(dtype)
else: else:
self.clip_l = None self.clip_l = None
@@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None: if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None: if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1) cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
else: else:
lg_out = torch.nn.functional.pad(g_out, (768, 0)) lg_out = torch.nn.functional.pad(g_out, (768, 0))
else: else:

View File

@@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
@@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
for b in range(samples.shape[0]): for b in range(samples.shape[0]):
s = samples[b:b+1] s = samples[b:b+1]
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s s_in = s
upscaled = [] upscaled = []
@@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
l = min(tile[d], s.shape[d + 2] - pos) l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l) s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount)) upscaled.append(round(pos * upscale_amount))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount) feather = round(overlap * upscale_amount)
for t in range(feather): for t in range(feather):
for d in range(2, dims + 2): for d in range(2, dims + 2):
m = mask.narrow(d, t, 1) a = (t + 1) / feather
m *= ((1.0/feather) * (t + 1)) mask.narrow(d, t, 1).mul_(a)
m = mask.narrow(d, mask.shape[d] -1 -t, 1) mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
m *= ((1.0/feather) * (t + 1))
o = out o = out
o_d = out_div o_d = out_div
@@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask o.add_(ps * mask)
o_d += mask o_d.add_(mask)
if pbar is not None: if pbar is not None:
pbar.update(1) pbar.update(1)

View File

@@ -1,11 +1,21 @@
import itertools import itertools
from typing import Sequence, Mapping from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
import nodes import nodes
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet: class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {} self.keys = {}
@@ -98,7 +108,7 @@ class CacheKeySetInputSignature(CacheKeySet):
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)] signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id) signature.append(node_id)
inputs = node["inputs"] inputs = node["inputs"]
for key in sorted(inputs.keys()): for key in sorted(inputs.keys()):

View File

@@ -99,30 +99,44 @@ class TopologicalSort:
self.add_strong_link(from_node_id, from_socket, to_node_id) self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id): def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id) if not self.is_cached(from_node_id):
if to_node_id not in self.blocking[from_node_id]: self.add_node(from_node_id)
self.blocking[from_node_id][to_node_id] = {} if to_node_id not in self.blocking[from_node_id]:
self.blockCount[to_node_id] += 1 self.blocking[from_node_id][to_node_id] = {}
self.blocking[from_node_id][to_node_id][from_socket] = True self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes: node_ids = [node_unique_id]
return links = []
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"] while len(node_ids) > 0:
for input_name in inputs: unique_id = node_ids.pop()
value = inputs[input_name] if unique_id in self.pendingNodes:
if is_link(value): continue
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes: self.pendingNodes[unique_id] = True
continue self.blockCount[unique_id] = 0
input_type, input_category, input_info = self.get_input_info(unique_id, input_name) self.blocking[unique_id] = {}
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy: inputs = self.dynprompt.get_node(unique_id)["inputs"]
self.add_strong_link(from_node_id, from_socket, unique_id) for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)
def is_cached(self, node_id):
return False
def get_ready_nodes(self): def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
self.output_cache = output_cache self.output_cache = output_cache
self.staged_node_id = None self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id): def is_cached(self, node_id):
if self.output_cache.get(from_node_id) is not None: return self.output_cache.get(node_id) is not None
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def stage_node_execution(self): def stage_node_execution(self):
assert self.staged_node_id is None assert self.staged_node_id is None

View File

@@ -16,14 +16,15 @@ class EmptyLatentAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}} return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "generate" FUNCTION = "generate"
CATEGORY = "latent/audio" CATEGORY = "latent/audio"
def generate(self, seconds): def generate(self, seconds, batch_size):
batch_size = 1
length = round((seconds * 44100 / 2048) / 2) * 2 length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device) latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, ) return ({"samples":latent, "type": "audio"}, )
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples): def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1) audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, ) return ({"waveform": audio, "sample_rate": 44100}, )
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
} }
class LoadAudio: class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}} return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio" CATEGORY = "audio"

View File

@@ -1,4 +1,6 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
class SetUnionControlNetType: class SetUnionControlNetType:
@classmethod @classmethod
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
return (control_net,) return (control_net,)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType, "SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
} }

View File

@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, ) return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler: class SDTurboScheduler:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler, "KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler, "ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler, "VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler, "BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler, "SDTurboScheduler": SDTurboScheduler,

View File

@@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength): def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone() model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength) patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None: if patch is not None:

View File

@@ -0,0 +1,115 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL",),
"text_encoder_diff": ("CLIP",)},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}

View File

@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling") model_sampling = model.get_model_object("model_sampling")

View File

@@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale): def patch(self, model, empty_conditioning, neg_scale):
m = model.clone() m = model.clone()

View File

@@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker" CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name): def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder() photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data: if "id_encoder" in data:

View File

@@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3): def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return (clip,)
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3" CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609 latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, ) return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3:
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}} }}
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
DEPRECATED = True
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader, "TripleCLIPLoader": TripleCLIPLoader,
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", "ControlNetApplySD3": "Apply Controlnet with VAE",
} }

View File

@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate" FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade" CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae): def generate(self, image, vae):

View File

@@ -154,7 +154,7 @@ class TomePatchModel:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "model_patches/unet"
def patch(self, model, ratio): def patch(self, model, ratio):
self.u = None self.u = None

View File

@@ -0,0 +1,22 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_model(self, model_name): def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name) model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True) sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models" CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2]) return (out[0], out[3], out[2])
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
return (m, ) return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing" CATEGORY = "advanced/model_merging"
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):

View File

@@ -37,6 +37,7 @@ class SaveImageWebsocket:
return {} return {}
@classmethod
def IS_CHANGED(s, images): def IS_CHANGED(s, images):
return time.time() return time.time()

View File

@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
# merge node execution results # merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list): for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list: if is_list:
output.append([x for o in results for x in o[i]]) value = []
for o in results:
if isinstance(o[i], ExecutionBlocker):
value.append(o[i])
else:
value.extend(o[i])
output.append(value)
else: else:
output.append([o[i] for o in results]) output.append([o[i] for o in results])
return output return output

View File

@@ -25,11 +25,16 @@ a111:
#comfyui: #comfyui:
# base_path: path/to/comfyui/ # base_path: path/to/comfyui/
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/ # checkpoints: models/checkpoints/
# clip: models/clip/ # clip: models/clip/
# clip_vision: models/clip_vision/ # clip_vision: models/clip_vision/
# configs: models/configs/ # configs: models/configs/
# controlnet: models/controlnet/ # controlnet: models/controlnet/
# diffusion_models: |
# models/diffusion_models
# models/unet
# embeddings: models/embeddings/ # embeddings: models/embeddings/
# loras: models/loras/ # loras: models/loras/
# upscale_models: models/upscale_models/ # upscale_models: models/upscale_models/

View File

@@ -2,7 +2,9 @@ from __future__ import annotations
import os import os
import time import time
import mimetypes
import logging import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str: def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"} legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name) return legacy.get(folder_name, folder_name)
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
global input_directory global input_directory
return input_directory return input_directory
def get_user_directory() -> str:
return user_directory
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
#NOTE: used in http server so don't put folders that should not be accessed remotely #NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None: def get_directory_by_type(type_name: str) -> str | None:
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory() return get_input_directory()
return None return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir # otherwise use default_path as base_dir
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
return os.path.exists(filepath) return os.path.exists(filepath)
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths: if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path) if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else: else:
folder_names_and_paths[folder_name] = ([full_folder_path], set()) folder_names_and_paths[folder_name] = ([full_folder_path], set())
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames: for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) try:
result.append(relative_path) relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs: for d in subdirs:
path: str = os.path.join(dirpath, d) path: str = os.path.join(dirpath, d)
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
global folder_names_and_paths global folder_names_and_paths
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter() return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache global filename_list_cache
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name) folder_name = map_legacy(folder_name)
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
out = get_filename_list_(folder_name) out = get_filename_list_(folder_name)
global filename_list_cache global filename_list_cache
filename_list_cache[folder_name] = out filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0]) return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]: def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str: def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width)) input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height)) input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height) if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix)) subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix)) filename = os.path.basename(os.path.normpath(filename_prefix))

View File

@@ -9,7 +9,7 @@ import folder_paths
import comfy.utils import comfy.utils
import logging import logging
MAX_PREVIEW_RESOLUTION = 512 MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image): def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image) return preview_to_image(latent_image)
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer return previewer
def prepare_callback(model, steps, x0_output_dict=None): def prepare_callback(model, steps, x0_output_dict=None):

43
main.py
View File

@@ -9,7 +9,7 @@ from comfy.cli_args import args
from app.logger import setup_logger from app.logger import setup_logger
setup_logger(verbose=args.verbose) setup_logger(log_level=args.verbose)
def execute_prestartup_script(): def execute_prestartup_script():
@@ -63,6 +63,7 @@ import threading
import gc import gc
import logging import logging
import utils.extra_config
if os.name == "nt": if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -85,7 +86,6 @@ if args.windows_standalone_build:
pass pass
import comfy.utils import comfy.utils
import yaml
import execution import execution
import server import server
@@ -160,7 +160,10 @@ def prompt_worker(q, server):
need_gc = False need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
@@ -180,27 +183,6 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__": if __name__ == "__main__":
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
@@ -222,11 +204,11 @@ if __name__ == "__main__":
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path): if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path) utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config: if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config): for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path) utils.extra_config.load_extra_path_config(config_path)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
@@ -247,21 +229,30 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}") logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
exit(0) exit(0)
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None call_on_start = None
if args.auto_launch: if args.auto_launch:
def startup_server(scheme, address, port): def startup_server(scheme, address, port):
import webbrowser import webbrowser
if os.name == 'nt' and address == '0.0.0.0': if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1' address = '127.0.0.1'
if ':' in address:
address = "[{}]".format(address)
webbrowser.open(f"{scheme}://{address}:{port}") webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server

View File

@@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@@ -1,9 +1,10 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations from __future__ import annotations
import aiohttp import aiohttp
import os import os
import traceback import traceback
import logging import logging
from folder_paths import models_dir from folder_paths import folder_names_and_paths, get_folder_paths
import re import re
from typing import Callable, Any, Optional, Awaitable, Dict from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum from enum import Enum
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed" COMPLETED = "completed"
ERROR = "error" ERROR = "error"
@dataclass @dataclass
class DownloadModelStatus(): class DownloadModelStatus():
status: str status: str
@@ -29,7 +31,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"status": self.status, "status": self.status,
@@ -38,102 +40,112 @@ class DownloadModelStatus():
"already_existed": self.already_existed "already_existed": self.already_existed
} }
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus: progress_interval: float = 1.0) -> DownloadModelStatus:
""" """
Download a model file from a given URL into the models directory. Download a model file from a given URL into the models directory.
Args: Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests. A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str): model_name (str):
The name of the model file to be downloaded. This will be the filename on disk. The name of the model file to be downloaded. This will be the filename on disk.
model_url (str): model_url (str):
The URL from which to download the model. The URL from which to download the model.
model_sub_directory (str): model_directory (str):
The subdirectory within the main models directory where the model The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.). should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates. An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns: Returns:
DownloadModelStatus: The result of the download operation. DownloadModelStatus: The result of the download operation.
""" """
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name): if not validate_filename(model_name):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0, 0,
"Invalid model name", "Invalid model name",
False False
) )
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) if not model_directory in folder_names_and_paths:
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file: if existing_file:
return existing_file return existing_file
try: try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
response = await model_download_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message) logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e: except Exception as e:
logging.error(f"Error in downloading model: {e}") logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory) def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory # Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path) abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir)) abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}") raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
relative_path = '/'.join([model_directory, model_name]) async def check_file_exists(file_path: str,
return file_path, relative_path model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
async def check_file_exists(file_path: str, ) -> Optional[DownloadModelStatus]:
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str, file_path: str,
model_name: str, model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus: interval: float = 1.0) -> DownloadModelStatus:
try: try:
total_size = int(response.headers.get('Content-Length', 0)) total_size = int(response.headers.get('Content-Length', 0))
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
nonlocal last_update_time nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0 progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
last_update_time = time.time() last_update_time = time.time()
with open(file_path, 'wb') as f: temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192) chunk_iterator = response.content.iter_chunked(8192)
while True: while True:
try: try:
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break break
f.write(chunk) f.write(chunk)
downloaded += len(chunk) downloaded += len(chunk)
if time.time() - last_update_time >= interval: if time.time() - last_update_time >= interval:
await update_progress() await update_progress()
os.rename(temp_file_path, file_path)
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}") logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str, async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any], model_name: str,
relative_path: str) -> DownloadModelStatus: progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool: def validate_filename(filename: str)-> bool:
""" """
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args: Args:
filename (str): The filename to validate filename (str): The filename to validate

View File

@@ -511,10 +511,11 @@ class CheckpointLoader:
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
DEPRECATED = True
def load_checkpoint(self, config_name, ckpt_name): def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name): def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3] return out[:3]
@@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
@@ -624,7 +625,7 @@ class LoraLoader:
if strength_model == 0 and strength_clip == 0: if strength_model == 0 and strength_clip == 0:
return (model, clip) return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name) lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None lora = None
if self.loaded_lora is not None: if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path: if self.loaded_lora[0] == lora_path:
@@ -703,11 +704,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc: for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k] sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec: for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k] sd["taesd_decoder.{}".format(k)] = dec[k]
@@ -738,7 +739,7 @@ class VAELoader:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
@@ -754,7 +755,7 @@ class ControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, control_net_name): def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path) controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,) return (controlnet,)
@@ -770,7 +771,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name): def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,) return (controlnet,)
@@ -786,6 +787,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet" FUNCTION = "apply_controlnet"
DEPRECATED = True
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength): def apply_controlnet(self, conditioning, control_net, image, strength):
@@ -815,7 +817,10 @@ class ControlNetApplyAdvanced:
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}} },
"optional": {"vae": ("VAE", ),
}
}
RETURN_TYPES = ("CONDITIONING","CONDITIONING") RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative") RETURN_NAMES = ("positive", "negative")
@@ -823,7 +828,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None): def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0: if strength == 0:
return (positive, negative) return (positive, negative)
@@ -840,7 +845,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets: if prev_cnet in cnets:
c_net = cnets[prev_cnet] c_net = cnets[prev_cnet]
else: else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
c_net.set_previous_controlnet(prev_cnet) c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net cnets[prev_cnet] = c_net
@@ -856,7 +861,7 @@ class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet" FUNCTION = "load_unet"
@@ -867,10 +872,13 @@ class UNETLoader:
model_options = {} model_options = {}
if weight_dtype == "fp8_e4m3fn": if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2": elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2 model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("diffusion_models", unet_name) unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,) return (model,)
@@ -895,7 +903,7 @@ class CLIPLoader:
else: else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
@@ -912,8 +920,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type): def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl": if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3": elif type == "sd3":
@@ -935,7 +943,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path) clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,) return (clip_vision,)
@@ -965,7 +973,7 @@ class StyleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_style_model(self, style_model_name): def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path) style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,) return (style_model,)
@@ -1030,7 +1038,7 @@ class GLIGENLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_gligen(self, gligen_name): def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name) gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path) gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,) return (gligen,)
@@ -1916,8 +1924,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)", "ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet", "ControlNetApply": "Apply ControlNet (OLD)",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", "ControlNetApplyAdvanced": "Apply ControlNet",
# Latent # Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask", "SetLatentNoiseMask": "Set Latent Noise Mask",
@@ -2101,6 +2109,8 @@ def init_builtin_extra_nodes():
"nodes_controlnet.py", "nodes_controlnet.py",
"nodes_hunyuan.py", "nodes_hunyuan.py",
"nodes_flux.py", "nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
] ]
import_failed = [] import_failed = []

View File

@@ -38,6 +38,9 @@ def get_images(ws, prompt):
if data['node'] is None and data['prompt_id'] == prompt_id: if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done break #Execution is done
else: else:
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
# bytesIO = BytesIO(out[8:])
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
continue #previews are binary data continue #previews are binary data
history = get_history(prompt_id)[prompt_id] history = get_history(prompt_id)[prompt_id]
@@ -151,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt) images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images: #Commented out code to display the output images:
# for node_id in images: # for node_id in images:

View File

@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt) images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images: #Commented out code to display the output images:
# for node_id in images: # for node_id in images:

128
server.py
View File

@@ -12,6 +12,8 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import socket
import ipaddress
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@@ -80,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
#in that case the Host and Origin hostnames won't match
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
host_domain = host.lower()
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
return response
return origin_only_middleware
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
PromptServer.instance = self PromptServer.instance = self
@@ -99,6 +163,8 @@ class PromptServer():
middlewares = [cache_control] middlewares = [cache_control]
if args.enable_cors_header: if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header)) middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
middlewares.append(create_origin_only_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024) max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
@@ -155,6 +221,12 @@ class PromptServer():
def get_embeddings(self): def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings") embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
return web.json_response(model_types)
@routes.get("/models/{folder}") @routes.get("/models/{folder}")
async def get_models(request): async def get_models(request):
@@ -418,12 +490,17 @@ class PromptServer():
async def system_stats(request): async def system_stats(request):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device) device_name = comfy.model_management.get_torch_device_name(device)
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = { system_stats = {
"system": { "system": {
"os": os.name, "os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(), "comfyui_version": get_comfyui_version(),
"python_version": sys.version, "python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version, "pytorch_version": comfy.model_management.torch_version,
@@ -480,14 +557,15 @@ class PromptServer():
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
out = {} with folder_paths.cache_helper:
for x in nodes.NODE_CLASS_MAPPINGS: out = {}
try: for x in nodes.NODE_CLASS_MAPPINGS:
out[x] = node_info(x) try:
except Exception as e: out[x] = node_info(x)
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") except Exception as e:
logging.error(traceback.format_exc()) logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
return web.json_response(out) logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
async def get_object_info_node(request): async def get_object_info_node(request):
@@ -601,6 +679,7 @@ class PromptServer():
# Internal route. Should not be depended upon and is subject to change at any time. # Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket. # TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download") @routes.post("/internal/models/download")
async def download_handler(request): async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus): async def report_progress(filename: str, status: DownloadModelStatus):
@@ -611,10 +690,11 @@ class PromptServer():
data = await request.json() data = await request.json()
url = data.get('url') url = data.get('url')
model_directory = data.get('model_directory') model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename') model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session session = self.client_session
@@ -622,7 +702,7 @@ class PromptServer():
logging.error("Client session is not initialized") logging.error("Client session is not initialized")
return web.Response(status=500) return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task await task
return web.json_response(task.result().to_dict()) return web.json_response(task.result().to_dict())
@@ -739,6 +819,9 @@ class PromptServer():
await self.send(*msg) await self.send(*msg)
async def start(self, address, port, verbose=True, call_on_start=None): async def start(self, address, port, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None) runner = web.AppRunner(self.app, access_log=None)
await runner.setup() await runner.setup()
ssl_ctx = None ssl_ctx = None
@@ -749,17 +832,26 @@ class PromptServer():
keyfile=args.tls_keyfile) keyfile=args.tls_keyfile)
scheme = "https" scheme = "https"
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) logging.info("Starting server\n")
await site.start() for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
self.address = address if not hasattr(self, 'address'):
self.port = port self.address = address #TODO: remove this
self.port = port
if ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if verbose:
logging.info("Starting server\n")
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(scheme, address, port) call_on_start(scheme, self.address, self.port)
def add_on_prompt_handler(self, handler): def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler) self.on_prompt_handlers.append(handler)

View File

@@ -2,7 +2,7 @@
## Install test dependencies ## Install test dependencies
`pip install -r tests-units/requirements.txt` `pip install -r tests-unit/requirements.txt`
## Run tests ## Run tests
`pytest tests-units/` `pytest tests-unit/`

View File

@@ -0,0 +1,66 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
def test_get_directory_by_type():
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"

View File

View File

@@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)
def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

@@ -1,10 +1,17 @@
import pytest import pytest
import tempfile
import aiohttp import aiohttp
from aiohttp import ClientResponse from aiohttp import ClientResponse
import itertools import itertools
import os import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock: class AsyncIteratorMock:
""" """
@@ -42,7 +49,7 @@ class ContentMock:
return AsyncIteratorMock(self.chunks) return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200 mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'} mock_response.headers = {'Content-Length': '1000'}
@@ -53,15 +60,13 @@ async def test_download_model_success():
mock_make_request = AsyncMock(return_value=mock_response) mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1) time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \ patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \ patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model( result = await download_model(
@@ -69,6 +74,7 @@ async def test_download_model_success():
'model.sft', 'model.sft',
'http://example.com/model.sft', 'http://example.com/model.sft',
'checkpoints', 'checkpoints',
temp_dir,
mock_progress_callback mock_progress_callback
) )
@@ -83,44 +89,48 @@ async def test_download_model_success():
# Check initial call # Check initial call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
) )
# Check final call # Check final call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
) )
# Verify file writing mock_file_path = os.path.join(temp_dir, 'model.sft')
mock_file.write.assert_any_call(b'a' * 500) assert os.path.exists(mock_file_path)
mock_file.write.assert_any_call(b'b' * 300) with open(mock_file_path, 'rb') as mock_file:
mock_file.write.assert_any_call(b'c' * 200) assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made # Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft') mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_url_request_failure(): async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies # Mock dependencies
mock_response = AsyncMock(spec=ClientResponse) mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response) mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function # Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
# Mock the check_file_exists function to return None (file doesn't exist) patch('model_filemanager.check_file_exists', return_value=None), \
with patch('model_filemanager.check_file_exists', return_value=None): patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function # Call the function
result = await download_model( result = await download_model(
mock_get, mock_get,
'model.safetensors', 'model.safetensors',
'http://example.com/model.safetensors', 'http://example.com/model.safetensors',
'mock_directory', 'checkpoints',
mock_progress_callback temp_dir,
) mock_progress_callback
)
# Assert the expected behavior # Assert the expected behavior
assert isinstance(result, DownloadModelStatus) assert isinstance(result, DownloadModelStatus)
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
# Check that progress_callback was called with the correct arguments # Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors', 'model.safetensors',
DownloadModelStatus( DownloadModelStatus(
status=DownloadStatusType.PENDING, status=DownloadStatusType.PENDING,
progress_percentage=0, progress_percentage=0,
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
) )
) )
mock_progress_callback.assert_called_with( mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors', 'model.safetensors',
DownloadModelStatus( DownloadModelStatus(
status=DownloadStatusType.ERROR, status=DownloadStatusType.ERROR,
progress_percentage=0, progress_percentage=0,
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory(): async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock() mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
result = await download_model( result = await download_model(
mock_make_request, mock_make_request,
'model.sft', 'model.sft',
'http://example.com/model.sft', 'http://example.com/model.sft',
'../bad_path', '../bad_path',
'../bad_path',
mock_progress_callback mock_progress_callback
) )
# Assert the result # Assert the result
assert isinstance(result, DownloadModelStatus) assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory' assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error' assert result.status == 'error'
assert result.already_existed is False assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch): def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models" model_name = "model.safetensors"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) folder_path = os.path.join(tmp_path, "mock_dir")
model_name = "test_model.sft" file_path = create_model_path(model_name, folder_path)
model_directory = "test_dir"
assert file_path == os.path.join(folder_path, "model.safetensors")
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path)) assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path): async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft" file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file file_path.touch() # Create an empty file
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None assert result is not None
assert result.status == "completed" assert result.status == "completed"
assert result.message == "existing_model.sft already exists" assert result.message == "existing_model.sft already exists"
assert result.already_existed is True assert result.already_existed is True
mock_callback.assert_called_once_with( mock_callback.assert_called_once_with(
"test/existing_model.sft", "existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path): async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft" file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None assert result is None
mock_callback.assert_not_called() mock_callback.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_no_content_length(): async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock() mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open): full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft', result = await track_download_progress(
mock_callback, 'models/model.sft', interval=0.1 mock_response, full_path, 'model.sft',
) mock_callback, interval=0.1
)
assert result.status == "completed" assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size # Check that progress was reported even without knowing the total size
mock_callback.assert_any_call( mock_callback.assert_any_call(
'models/model.sft', 'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_interval(): async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'} mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock() mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock()) mock_open = MagicMock(return_value=MagicMock())
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
mock_time = MagicMock() mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \ full_path = os.path.join(temp_dir, 'model.sft')
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)
# Print out the actual call count and the arguments of each call for debugging with patch('time.time', mock_time):
print(f"mock_callback was called {mock_callback.call_count} times") await track_download_progress(
for i, call in enumerate(mock_callback.call_args_list): mock_response, full_path, 'model.sft',
args, kwargs = call mock_callback, interval=1.0
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") )
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end) # Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
assert last_call[0][1].status == "completed" assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100 assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [ @pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True), ("valid_model.safetensors", True),
("valid_model.sft", True), ("valid_model.sft", True),

View File

@@ -0,0 +1,120 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"

View File

@@ -0,0 +1,126 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config
import folder_paths
@pytest.fixture
def mock_yaml_content():
return {
'test_config': {
'base_path': '~/App/',
'checkpoints': 'subfolder1',
}
}
@pytest.fixture
def mock_expanded_home():
return '/home/user'
@pytest.fixture
def yaml_config_with_appdata():
return """
test_config:
base_path: '%APPDATA%/ComfyUI'
checkpoints: 'models/checkpoints'
"""
@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture
def mock_expandvars_appdata():
mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock
@pytest.fixture
def mock_add_model_folder_path():
return Mock()
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_calls = [
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args[0] == expected_call[0]
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
assert actual_call.args[2] == expected_call[2]
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
):
# Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata
# Attach mocks
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
# Mock expanduser to do nothing (since we're not testing it here)
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
expected_calls = [
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
# Verify that expandvars was called
assert mock_expandvars_appdata.called

View File

@@ -496,3 +496,29 @@ class TestExecution:
assert len(images) == 1, "Should have 1 image" assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached" assert not result.did_run(test_node), "The execution should have been cached"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
int3 = g.node("StubInt", value=3)
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g)
assert result.did_run(output), "The execution should have run"
images = result.get_images(output)
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"

0
utils/__init__.py Normal file
View File

28
utils/extra_config.py Normal file
View File

@@ -0,0 +1,28 @@
import os
import yaml
import folder_paths
import logging
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = os.path.expandvars(os.path.expanduser(base_path))
is_default = False
if "is_default" in conf:
is_default = conf.pop("is_default")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path, is_default)

1
web/assets/CREDIT.txt generated vendored Normal file
View File

@@ -0,0 +1 @@
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.

792
web/assets/GraphView-BGt8GmeB.css generated vendored Normal file
View File

@@ -0,0 +1,792 @@
.editable-text[data-v-54da6fc9] {
display: inline;
}
.editable-text input[data-v-54da6fc9] {
width: 100%;
box-sizing: border-box;
}
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
z-index: 9999;
padding: 0.25rem;
}
[data-v-fc3f26e3] .editable-text {
width: 100%;
height: 100%;
}
[data-v-fc3f26e3] .editable-text input {
width: 100%;
height: 100%;
/* Override the default font size */
font-size: inherit;
}
.side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
font-weight: bold;
}
.side-bar-button[data-v-caa3ee9c] {
width: var(--sidebar-width);
height: var(--sidebar-width);
border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-right: 4px solid var(--p-button-text-primary-color);
}
:root {
--sidebar-width: 64px;
--sidebar-icon-size: 1.5rem;
}
:root .small-sidebar {
--sidebar-width: 40px;
--sidebar-icon-size: 1rem;
}
.side-tool-bar-container[data-v-4da64512] {
display: flex;
flex-direction: column;
align-items: center;
pointer-events: auto;
width: var(--sidebar-width);
height: 100%;
background-color: var(--comfy-menu-bg);
color: var(--fg-color);
}
.side-tool-bar-end[data-v-4da64512] {
align-self: flex-end;
margin-top: auto;
}
.sidebar-content-container[data-v-4da64512] {
height: 100%;
overflow-y: auto;
}
.p-splitter-gutter {
pointer-events: auto;
}
.gutter-hidden {
display: none !important;
}
.side-bar-panel[data-v-b9df3042] {
background-color: var(--bg-color);
pointer-events: auto;
}
.splitter-overlay[data-v-b9df3042] {
width: 100%;
height: 100%;
position: absolute;
top: 0;
left: 0;
background-color: transparent;
pointer-events: none;
/* Set it the same as the ComfyUI menu */
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
999 should be sufficient to make sure splitter overlays on node's DOM
widgets */
z-index: 999;
border: none;
}
._content[data-v-e7b35fd9] {
display: flex;
flex-direction: column
}
._content[data-v-e7b35fd9] > :not([hidden]) ~ :not([hidden]) {
--tw-space-y-reverse: 0;
margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse)));
margin-bottom: calc(0.5rem * var(--tw-space-y-reverse))
}
._footer[data-v-e7b35fd9] {
display: flex;
flex-direction: column;
align-items: flex-end;
padding-top: 1rem
}
[data-v-37f672ab] .highlight {
background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color);
font-weight: bold;
border-radius: 0.25rem;
padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem;
}
.slot_row[data-v-ff07c900] {
padding: 2px;
}
/* Original N-Sidebar styles */
._sb_dot[data-v-ff07c900] {
width: 8px;
height: 8px;
border-radius: 50%;
background-color: grey;
}
.node_header[data-v-ff07c900] {
line-height: 1;
padding: 8px 13px 7px;
margin-bottom: 5px;
font-size: 15px;
text-wrap: nowrap;
overflow: hidden;
display: flex;
align-items: center;
}
.headdot[data-v-ff07c900] {
width: 10px;
height: 10px;
float: inline-start;
margin-right: 8px;
}
.IMAGE[data-v-ff07c900] {
background-color: #64b5f6;
}
.VAE[data-v-ff07c900] {
background-color: #ff6e6e;
}
.LATENT[data-v-ff07c900] {
background-color: #ff9cf9;
}
.MASK[data-v-ff07c900] {
background-color: #81c784;
}
.CONDITIONING[data-v-ff07c900] {
background-color: #ffa931;
}
.CLIP[data-v-ff07c900] {
background-color: #ffd500;
}
.MODEL[data-v-ff07c900] {
background-color: #b39ddb;
}
.CONTROL_NET[data-v-ff07c900] {
background-color: #a5d6a7;
}
._sb_node_preview[data-v-ff07c900] {
background-color: var(--comfy-menu-bg);
font-family: 'Open Sans', sans-serif;
font-size: small;
color: var(--descrip-text);
border: 1px solid var(--descrip-text);
min-width: 300px;
width: -moz-min-content;
width: min-content;
height: -moz-fit-content;
height: fit-content;
z-index: 9999;
border-radius: 12px;
overflow: hidden;
font-size: 12px;
padding-bottom: 10px;
}
._sb_node_preview ._sb_description[data-v-ff07c900] {
margin: 10px;
padding: 6px;
background: var(--border-color);
border-radius: 5px;
font-style: italic;
font-weight: 500;
font-size: 0.9rem;
word-break: break-word;
}
._sb_table[data-v-ff07c900] {
display: grid;
grid-column-gap: 10px;
/* Spazio tra le colonne */
width: 100%;
/* Imposta la larghezza della tabella al 100% del contenitore */
}
._sb_row[data-v-ff07c900] {
display: grid;
grid-template-columns: 10px 1fr 1fr 1fr 10px;
grid-column-gap: 10px;
align-items: center;
padding-left: 9px;
padding-right: 9px;
}
._sb_row_string[data-v-ff07c900] {
grid-template-columns: 10px 1fr 1fr 10fr 1fr;
}
._sb_col[data-v-ff07c900] {
border: 0px solid #000;
display: flex;
align-items: flex-end;
flex-direction: row-reverse;
flex-wrap: nowrap;
align-content: flex-start;
justify-content: flex-end;
}
._sb_inherit[data-v-ff07c900] {
display: inherit;
}
._long_field[data-v-ff07c900] {
background: var(--bg-color);
border: 2px solid var(--border-color);
margin: 5px 5px 0 5px;
border-radius: 10px;
line-height: 1.7;
text-wrap: nowrap;
}
._sb_arrow[data-v-ff07c900] {
color: var(--fg-color);
}
._sb_preview_badge[data-v-ff07c900] {
text-align: center;
background: var(--comfy-input-bg);
font-weight: bold;
color: var(--error-text);
}
.comfy-vue-node-search-container[data-v-2d409367] {
display: flex;
width: 100%;
min-width: 26rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-2d409367] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-2d409367] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-2d409367] {
z-index: 10;
flex-grow: 1;
}
._filter-button[data-v-2d409367] {
z-index: 10;
}
._dialog[data-v-2d409367] {
min-width: 26rem;
}
.invisible-dialog-root {
width: 60%;
min-width: 24rem;
max-width: 48rem;
border: 0 !important;
background-color: transparent !important;
margin-top: 25vh;
margin-left: 400px;
}
@media all and (max-width: 768px) {
.invisible-dialog-root {
margin-left: 0px;
}
}
.node-search-box-dialog-mask {
align-items: flex-start !important;
}
.node-tooltip[data-v-0a4402f9] {
background: var(--comfy-input-bg);
border-radius: 5px;
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
color: var(--input-text);
font-family: sans-serif;
left: 0;
max-width: 30vw;
padding: 4px 8px;
position: absolute;
top: 0;
transform: translate(5px, calc(-100% - 5px));
white-space: pre-wrap;
z-index: 99999;
}
.p-buttongroup-vertical[data-v-ce8bd6ac] {
display: flex;
flex-direction: column;
border-radius: var(--p-button-border-radius);
overflow: hidden;
border: 1px solid var(--p-panel-border-color);
}
.p-buttongroup-vertical .p-button[data-v-ce8bd6ac] {
margin: 0;
border-radius: 0;
}
.comfy-image-wrap[data-v-9bc23daf] {
display: contents;
}
.comfy-image-blur[data-v-9bc23daf] {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
}
.comfy-image-main[data-v-9bc23daf] {
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
-o-object-position: center;
object-position: center;
z-index: 1;
}
.contain .comfy-image-wrap[data-v-9bc23daf] {
position: relative;
width: 100%;
height: 100%;
}
.contain .comfy-image-main[data-v-9bc23daf] {
-o-object-fit: contain;
object-fit: contain;
-webkit-backdrop-filter: blur(10px);
backdrop-filter: blur(10px);
position: absolute;
}
.broken-image-placeholder[data-v-9bc23daf] {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
width: 100%;
height: 100%;
margin: 2rem;
}
.broken-image-placeholder i[data-v-9bc23daf] {
font-size: 3rem;
margin-bottom: 0.5rem;
}
.result-container[data-v-d9c060ae] {
width: 100%;
height: 100%;
aspect-ratio: 1 / 1;
overflow: hidden;
position: relative;
display: flex;
justify-content: center;
align-items: center;
}
.image-preview-mask[data-v-d9c060ae] {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
display: flex;
align-items: center;
justify-content: center;
opacity: 0;
transition: opacity 0.3s ease;
z-index: 1;
}
.result-container:hover .image-preview-mask[data-v-d9c060ae] {
opacity: 1;
}
.task-result-preview[data-v-d4c8a1fe] {
aspect-ratio: 1 / 1;
overflow: hidden;
display: flex;
justify-content: center;
align-items: center;
width: 100%;
height: 100%;
}
.task-result-preview i[data-v-d4c8a1fe],
.task-result-preview span[data-v-d4c8a1fe] {
font-size: 2rem;
}
.task-item[data-v-d4c8a1fe] {
display: flex;
flex-direction: column;
border-radius: 4px;
overflow: hidden;
position: relative;
}
.task-item-details[data-v-d4c8a1fe] {
position: absolute;
bottom: 0;
padding: 0.6rem;
display: flex;
justify-content: space-between;
align-items: center;
width: 100%;
z-index: 1;
}
.task-node-link[data-v-d4c8a1fe] {
padding: 2px;
}
/* In dark mode, transparent background color for tags is not ideal for tags that
are floating on top of images. */
.tag-wrapper[data-v-d4c8a1fe] {
background-color: var(--p-primary-contrast-color);
border-radius: 6px;
display: inline-flex;
}
.node-name-tag[data-v-d4c8a1fe] {
word-break: break-all;
}
.status-tag-group[data-v-d4c8a1fe] {
display: flex;
flex-direction: column;
}
.progress-preview-img[data-v-d4c8a1fe] {
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
-o-object-position: center;
object-position: center;
}
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
cannot use scoped style here. */
img.galleria-image {
max-width: 100vw;
max-height: 100vh;
-o-object-fit: contain;
object-fit: contain;
}
.p-galleria-close-button {
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
z-index: 1;
}
.comfy-vue-side-bar-container[data-v-1b0a8fe3] {
display: flex;
flex-direction: column;
height: 100%;
overflow: hidden;
}
.comfy-vue-side-bar-header[data-v-1b0a8fe3] {
flex-shrink: 0;
border-left: none;
border-right: none;
border-top: none;
border-radius: 0;
padding: 0.25rem 1rem;
min-height: 2.5rem;
}
.comfy-vue-side-bar-header-span[data-v-1b0a8fe3] {
font-size: small;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3] {
flex-grow: 1;
overflow: auto;
scrollbar-width: thin;
scrollbar-color: transparent transparent;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar {
width: 1px;
}
.comfy-vue-side-bar-body[data-v-1b0a8fe3]::-webkit-scrollbar-thumb {
background-color: transparent;
}
.scroll-container[data-v-08fa89b1] {
height: 100%;
overflow-y: auto;
}
.queue-grid[data-v-08fa89b1] {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
padding: 0.5rem;
gap: 0.5rem;
}
.tree-node[data-v-633e27ab] {
width: 100%;
display: flex;
align-items: center;
justify-content: space-between;
}
.leaf-count-badge[data-v-633e27ab] {
margin-left: 0.5rem;
}
.node-content[data-v-633e27ab] {
display: flex;
align-items: center;
flex-grow: 1;
}
.leaf-label[data-v-633e27ab] {
margin-left: 0.5rem;
}
[data-v-633e27ab] .editable-text span {
word-break: break-all;
}
[data-v-bd7bae90] .tree-explorer-node-label {
width: 100%;
display: flex;
align-items: center;
margin-left: var(--p-tree-node-gap);
flex-grow: 1;
}
/*
* The following styles are necessary to avoid layout shift when dragging nodes over folders.
* By setting the position to relative on the parent and using an absolutely positioned pseudo-element,
* we can create a visual indicator for the drop target without affecting the layout of other elements.
*/
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder) {
position: relative;
}
[data-v-bd7bae90] .p-tree-node-content:has(.tree-folder.can-drop)::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
border: 1px solid var(--p-content-color);
pointer-events: none;
}
.node-lib-node-container[data-v-90dfee08] {
height: 100%;
width: 100%
}
.p-selectbutton .p-button[data-v-91077f2a] {
padding: 0.5rem;
}
.p-selectbutton .p-button .pi[data-v-91077f2a] {
font-size: 1.5rem;
}
.field[data-v-91077f2a] {
display: flex;
flex-direction: column;
gap: 0.5rem;
}
.color-picker-container[data-v-91077f2a] {
display: flex;
align-items: center;
gap: 0.5rem;
}
.node-lib-filter-popup {
margin-left: -13px;
}
[data-v-f6a7371a] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-f6a7371a] .node-lib-bookmark-tree-explorer {
padding-bottom: 2px;
}
[data-v-f6a7371a] .p-divider {
margin: var(--comfy-tree-explorer-item-padding) 0px;
}
.model_preview[data-v-32e6c4d9] {
background-color: var(--comfy-menu-bg);
font-family: 'Open Sans', sans-serif;
color: var(--descrip-text);
border: 1px solid var(--descrip-text);
min-width: 300px;
max-width: 500px;
width: -moz-fit-content;
width: fit-content;
height: -moz-fit-content;
height: fit-content;
z-index: 9999;
border-radius: 12px;
overflow: hidden;
font-size: 12px;
padding: 10px;
}
.model_preview_image[data-v-32e6c4d9] {
margin: auto;
width: -moz-fit-content;
width: fit-content;
}
.model_preview_image img[data-v-32e6c4d9] {
max-width: 100%;
max-height: 150px;
-o-object-fit: contain;
object-fit: contain;
}
.model_preview_title[data-v-32e6c4d9] {
font-weight: bold;
text-align: center;
font-size: 14px;
}
.model_preview_top_container[data-v-32e6c4d9] {
text-align: center;
line-height: 0.5;
}
.model_preview_filename[data-v-32e6c4d9],
.model_preview_author[data-v-32e6c4d9],
.model_preview_architecture[data-v-32e6c4d9] {
display: inline-block;
text-align: center;
margin: 5px;
font-size: 10px;
}
.model_preview_prefix[data-v-32e6c4d9] {
font-weight: bold;
}
.model-lib-model-icon-container[data-v-70b69131] {
display: inline-block;
position: relative;
left: 0;
height: 1.5rem;
vertical-align: top;
width: 0px;
}
.model-lib-model-icon[data-v-70b69131] {
background-size: cover;
background-position: center;
display: inline-block;
position: relative;
left: -2.5rem;
height: 2rem;
width: 2rem;
vertical-align: top;
}
.pi-fake-spacer {
height: 1px;
width: 16px;
}
[data-v-74b01bce] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-d2d58252] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
[data-v-84e785b8] .p-togglebutton::before {
display: none
}
[data-v-84e785b8] .p-togglebutton {
position: relative;
flex-shrink: 0;
border-radius: 0px;
background-color: transparent;
padding-left: 0.5rem;
padding-right: 0.5rem
}
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
border-bottom-width: 2px;
border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
visibility: visible
}
.status-indicator[data-v-84e785b8] {
position: absolute;
font-weight: 700;
font-size: 1.5rem;
top: 50%;
left: 50%;
transform: translate(-50%, -50%)
}
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
display: none
}
[data-v-84e785b8] .p-togglebutton .close-button {
visibility: hidden
}
.top-menubar[data-v-2ec1b620] .p-menubar-item-link svg {
display: none;
}
[data-v-2ec1b620] .p-menubar-submenu.dropdown-direction-up {
top: auto;
bottom: 100%;
flex-direction: column-reverse;
}
.keybinding-tag[data-v-2ec1b620] {
background: var(--p-content-hover-background);
border-color: var(--p-content-border-color);
border-style: solid;
}
[data-v-713442be] .p-inputtext {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
}
.comfyui-queue-button[data-v-fcd3efcd] .p-splitbutton-dropdown {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.actionbar[data-v-bc6c78dd] {
pointer-events: all;
position: fixed;
z-index: 1000;
}
.actionbar.is-docked[data-v-bc6c78dd] {
position: static;
border-style: none;
background-color: transparent;
padding: 0px;
}
.actionbar.is-dragging[data-v-bc6c78dd] {
-webkit-user-select: none;
-moz-user-select: none;
user-select: none;
}
[data-v-bc6c78dd] .p-panel-content {
padding: 0.25rem;
}
[data-v-bc6c78dd] .p-panel-header {
display: none;
}
.comfyui-menu[data-v-b13fdc92] {
width: 100vw;
background: var(--comfy-menu-bg);
color: var(--fg-color);
font-family: Arial, Helvetica, sans-serif;
font-size: 0.8em;
box-sizing: border-box;
z-index: 1000;
order: 0;
grid-column: 1/-1;
max-height: 90vh;
}
.comfyui-menu.dropzone[data-v-b13fdc92] {
background: var(--p-highlight-background);
}
.comfyui-menu.dropzone-active[data-v-b13fdc92] {
background: var(--p-highlight-background-focus);
}
.comfyui-logo[data-v-b13fdc92] {
font-size: 1.2em;
-webkit-user-select: none;
-moz-user-select: none;
user-select: none;
cursor: default;
}

17465
web/assets/GraphView-CVV2XJjS.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/GraphView-CVV2XJjS.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

865
web/assets/colorPalette-D5oi2-2V.js generated vendored Normal file
View File

@@ -0,0 +1,865 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { k as app, aP as LGraphCanvas, bO as useToastStore, ca as $el, z as LiteGraph } from "./index-DGAbdBYF.js";
const colorPalettes = {
dark: {
id: "dark",
name: "Dark (Default)",
colors: {
node_slot: {
CLIP: "#FFD500",
// bright yellow
CLIP_VISION: "#A8DADC",
// light blue-gray
CLIP_VISION_OUTPUT: "#ad7452",
// rusty brown-orange
CONDITIONING: "#FFA931",
// vibrant orange-yellow
CONTROL_NET: "#6EE7B7",
// soft mint green
IMAGE: "#64B5F6",
// bright sky blue
LATENT: "#FF9CF9",
// light pink-purple
MASK: "#81C784",
// muted green
MODEL: "#B39DDB",
// light lavender-purple
STYLE_MODEL: "#C2FFAE",
// light green-yellow
VAE: "#FF6E6E",
// bright red
NOISE: "#B0B0B0",
// gray
GUIDER: "#66FFFF",
// cyan
SAMPLER: "#ECB4B4",
// very soft red
SIGMAS: "#CDFFCD",
// soft lime green
TAESD: "#DCC274"
// cheesecake
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#222",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#FFF",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#AAA",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#333",
NODE_DEFAULT_BGCOLOR: "#353535",
NODE_DEFAULT_BOXCOLOR: "#666",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#FFF",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#222",
WIDGET_OUTLINE_COLOR: "#666",
WIDGET_TEXT_COLOR: "#DDD",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA",
BADGE_FG_COLOR: "#FFF",
BADGE_BG_COLOR: "#0F1F0F"
},
comfy_base: {
"fg-color": "#fff",
"bg-color": "#202020",
"comfy-menu-bg": "#353535",
"comfy-input-bg": "#222",
"input-text": "#ddd",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#4e4e4e",
"tr-even-bg-color": "#222",
"tr-odd-bg-color": "#353535",
"content-bg": "#4e4e4e",
"content-fg": "#fff",
"content-hover-bg": "#222",
"content-hover-fg": "#fff"
}
}
},
light: {
id: "light",
name: "Light",
colors: {
node_slot: {
CLIP: "#FFA726",
// orange
CLIP_VISION: "#5C6BC0",
// indigo
CLIP_VISION_OUTPUT: "#8D6E63",
// brown
CONDITIONING: "#EF5350",
// red
CONTROL_NET: "#66BB6A",
// green
IMAGE: "#42A5F5",
// blue
LATENT: "#AB47BC",
// purple
MASK: "#9CCC65",
// light green
MODEL: "#7E57C2",
// deep purple
STYLE_MODEL: "#D4E157",
// lime
VAE: "#FF7043"
// deep orange
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "lightgray",
NODE_TITLE_COLOR: "#222",
NODE_SELECTED_TITLE_COLOR: "#000",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#444",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#F7F7F7",
NODE_DEFAULT_BGCOLOR: "#F5F5F5",
NODE_DEFAULT_BOXCOLOR: "#CCC",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#000",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.1)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#D4D4D4",
WIDGET_OUTLINE_COLOR: "#999",
WIDGET_TEXT_COLOR: "#222",
WIDGET_SECONDARY_TEXT_COLOR: "#555",
LINK_COLOR: "#4CAF50",
EVENT_LINK_COLOR: "#FF9800",
CONNECTING_LINK_COLOR: "#2196F3",
BADGE_FG_COLOR: "#000",
BADGE_BG_COLOR: "#FFF"
},
comfy_base: {
"fg-color": "#222",
"bg-color": "#DDD",
"comfy-menu-bg": "#F5F5F5",
"comfy-input-bg": "#C9C9C9",
"input-text": "#222",
"descrip-text": "#444",
"drag-text": "#555",
"error-text": "#F44336",
"border-color": "#888",
"tr-even-bg-color": "#f9f9f9",
"tr-odd-bg-color": "#fff",
"content-bg": "#e0e0e0",
"content-fg": "#222",
"content-hover-bg": "#adadad",
"content-hover-fg": "#222"
}
}
},
solarized: {
id: "solarized",
name: "Solarized",
colors: {
node_slot: {
CLIP: "#2AB7CA",
// light blue
CLIP_VISION: "#6c71c4",
// blue violet
CLIP_VISION_OUTPUT: "#859900",
// olive green
CONDITIONING: "#d33682",
// magenta
CONTROL_NET: "#d1ffd7",
// light mint green
IMAGE: "#5940bb",
// deep blue violet
LATENT: "#268bd2",
// blue
MASK: "#CCC9E7",
// light purple-gray
MODEL: "#dc322f",
// red
STYLE_MODEL: "#1a998a",
// teal
UPSCALE_MODEL: "#054A29",
// dark green
VAE: "#facfad"
// light pink-orange
},
litegraph_base: {
NODE_TITLE_COLOR: "#fdf6e3",
// Base3
NODE_SELECTED_TITLE_COLOR: "#A9D400",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#657b83",
// Base00
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#094656",
NODE_DEFAULT_BGCOLOR: "#073642",
// Base02
NODE_DEFAULT_BOXCOLOR: "#839496",
// Base0
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#fdf6e3",
// Base3
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#002b36",
// Base03
WIDGET_OUTLINE_COLOR: "#839496",
// Base0
WIDGET_TEXT_COLOR: "#fdf6e3",
// Base3
WIDGET_SECONDARY_TEXT_COLOR: "#93a1a1",
// Base1
LINK_COLOR: "#2aa198",
// Solarized Cyan
EVENT_LINK_COLOR: "#268bd2",
// Solarized Blue
CONNECTING_LINK_COLOR: "#859900"
// Solarized Green
},
comfy_base: {
"fg-color": "#fdf6e3",
// Base3
"bg-color": "#002b36",
// Base03
"comfy-menu-bg": "#073642",
// Base02
"comfy-input-bg": "#002b36",
// Base03
"input-text": "#93a1a1",
// Base1
"descrip-text": "#586e75",
// Base01
"drag-text": "#839496",
// Base0
"error-text": "#dc322f",
// Solarized Red
"border-color": "#657b83",
// Base00
"tr-even-bg-color": "#002b36",
"tr-odd-bg-color": "#073642",
"content-bg": "#657b83",
"content-fg": "#fdf6e3",
"content-hover-bg": "#002b36",
"content-hover-fg": "#fdf6e3"
}
}
},
arc: {
id: "arc",
name: "Arc",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#2b2f38",
NODE_TITLE_COLOR: "#b2b7bd",
NODE_SELECTED_TITLE_COLOR: "#FFF",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#AAA",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#2b2f38",
NODE_DEFAULT_BGCOLOR: "#242730",
NODE_DEFAULT_BOXCOLOR: "#6e7581",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#FFF",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 22,
WIDGET_BGCOLOR: "#2b2f38",
WIDGET_OUTLINE_COLOR: "#6e7581",
WIDGET_TEXT_COLOR: "#DDD",
WIDGET_SECONDARY_TEXT_COLOR: "#b2b7bd",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730",
"content-bg": "#6e7581",
"content-fg": "#fff",
"content-hover-bg": "#2b2f38",
"content-hover-fg": "#fff"
}
}
},
nord: {
id: "nord",
name: "Nord",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#212732",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#bcc2c8",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#2e3440",
NODE_DEFAULT_BGCOLOR: "#161b22",
NODE_DEFAULT_BOXCOLOR: "#545d70",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#2e3440",
WIDGET_OUTLINE_COLOR: "#545d70",
WIDGET_TEXT_COLOR: "#bcc2c8",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22",
"content-bg": "#545d70",
"content-fg": "#e5eaf0",
"content-hover-bg": "#2e3440",
"content-hover-fg": "#e5eaf0"
}
}
},
github: {
id: "github",
name: "Github",
colors: {
node_slot: {
BOOLEAN: "",
CLIP: "#eacb8b",
CLIP_VISION: "#A8DADC",
CLIP_VISION_OUTPUT: "#ad7452",
CONDITIONING: "#cf876f",
CONTROL_NET: "#00d78d",
CONTROL_NET_WEIGHTS: "",
FLOAT: "",
GLIGEN: "",
IMAGE: "#80a1c0",
IMAGEUPLOAD: "",
INT: "",
LATENT: "#b38ead",
LATENT_KEYFRAME: "",
MASK: "#a3bd8d",
MODEL: "#8978a7",
SAMPLER: "",
SIGMAS: "",
STRING: "",
STYLE_MODEL: "#C2FFAE",
T2I_ADAPTER_WEIGHTS: "",
TAESD: "#DCC274",
TIMESTEP_KEYFRAME: "",
UPSCALE_MODEL: "",
VAE: "#be616b"
},
litegraph_base: {
BACKGROUND_IMAGE: "",
CLEAR_BACKGROUND_COLOR: "#040506",
NODE_TITLE_COLOR: "#999",
NODE_SELECTED_TITLE_COLOR: "#e5eaf0",
NODE_TEXT_SIZE: 14,
NODE_TEXT_COLOR: "#bcc2c8",
NODE_SUBTEXT_SIZE: 12,
NODE_DEFAULT_COLOR: "#161b22",
NODE_DEFAULT_BGCOLOR: "#13171d",
NODE_DEFAULT_BOXCOLOR: "#30363d",
NODE_DEFAULT_SHAPE: "box",
NODE_BOX_OUTLINE_COLOR: "#e5eaf0",
NODE_BYPASS_BGCOLOR: "#FF00FF",
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT: 24,
WIDGET_BGCOLOR: "#161b22",
WIDGET_OUTLINE_COLOR: "#30363d",
WIDGET_TEXT_COLOR: "#bcc2c8",
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
},
comfy_base: {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d",
"content-bg": "#30363d",
"content-fg": "#e5eaf0",
"content-hover-bg": "#161b22",
"content-hover-fg": "#e5eaf0"
}
}
}
};
const id = "Comfy.ColorPalette";
const idCustomColorPalettes = "Comfy.CustomColorPalettes";
const defaultColorPaletteId = "dark";
const els = {
select: null
};
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
}, "getCustomColorPalettes");
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
return app.ui.settings.setSettingValue(
idCustomColorPalettes,
customColorPalettes
);
}, "setCustomColorPalettes");
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
if (!colorPaletteId) {
colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[colorPaletteId]) {
return customColorPalettes[colorPaletteId];
}
}
return colorPalettes[colorPaletteId];
}, "getColorPalette");
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
app.ui.settings.setSettingValue(id, colorPaletteId);
}, "setColorPalette");
app.registerExtension({
name: id,
init() {
LGraphCanvas.prototype.updateBackground = function(image, clearBackgroundColor) {
this._bg_img = new Image();
this._bg_img.name = image;
this._bg_img.src = image;
this._bg_img.onload = () => {
this.draw(true, true);
};
this.background_image = image;
this.clear_background = true;
this.clear_background_color = clearBackgroundColor;
this._pattern = null;
};
},
addCustomNodeDefs(node_defs) {
const sortObjectKeys = /* @__PURE__ */ __name((unordered) => {
return Object.keys(unordered).sort().reduce((obj, key) => {
obj[key] = unordered[key];
return obj;
}, {});
}, "sortObjectKeys");
function getSlotTypes() {
var types = [];
const defs = node_defs;
for (const nodeId in defs) {
const nodeData = defs[nodeId];
var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] !== void 0) {
inputs = Object.assign(
{},
nodeData["input"]["required"],
nodeData["input"]["optional"]
);
}
for (const inputName in inputs) {
const inputData = inputs[inputName];
const type = inputData[0];
if (!Array.isArray(type)) {
types.push(type);
}
}
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
types.push(output);
}
}
return types;
}
__name(getSlotTypes, "getSlotTypes");
function completeColorPalette(colorPalette) {
var types = getSlotTypes();
for (const type of types) {
if (!colorPalette.colors.node_slot[type]) {
colorPalette.colors.node_slot[type] = "";
}
}
colorPalette.colors.node_slot = sortObjectKeys(
colorPalette.colors.node_slot
);
return colorPalette;
}
__name(completeColorPalette, "completeColorPalette");
const getColorPaletteTemplate = /* @__PURE__ */ __name(async () => {
let colorPalette = {
id: "my_color_palette_unique_id",
name: "My Color Palette",
colors: {
node_slot: {},
litegraph_base: {},
comfy_base: {}
}
};
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette2.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette2.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette);
}, "getColorPaletteTemplate");
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
if (typeof colorPalette !== "object") {
useToastStore().addAlert("Invalid color palette.");
return;
}
if (!colorPalette.id) {
useToastStore().addAlert("Color palette missing id.");
return;
}
if (!colorPalette.name) {
useToastStore().addAlert("Color palette missing name.");
return;
}
if (!colorPalette.colors) {
useToastStore().addAlert("Color palette missing colors.");
return;
}
if (colorPalette.colors.node_slot && typeof colorPalette.colors.node_slot !== "object") {
useToastStore().addAlert("Invalid color palette colors.node_slot.");
return;
}
const customColorPalettes = getCustomColorPalettes();
customColorPalettes[colorPalette.id] = colorPalette;
setCustomColorPalettes(customColorPalettes);
for (const option of els.select.childNodes) {
if (option.value === "custom_" + colorPalette.id) {
els.select.removeChild(option);
}
}
els.select.append(
$el("option", {
textContent: colorPalette.name + " (custom)",
value: "custom_" + colorPalette.id,
selected: true
})
);
setColorPalette("custom_" + colorPalette.id);
await loadColorPalette(colorPalette);
}, "addCustomColorPalette");
const deleteCustomColorPalette = /* @__PURE__ */ __name(async (colorPaletteId) => {
const customColorPalettes = getCustomColorPalettes();
delete customColorPalettes[colorPaletteId];
setCustomColorPalettes(customColorPalettes);
for (const opt of els.select.childNodes) {
const option = opt;
if (option.value === defaultColorPaletteId) {
option.selected = true;
}
if (option.value === "custom_" + colorPaletteId) {
els.select.removeChild(option);
}
}
setColorPalette(defaultColorPaletteId);
await loadColorPalette(getColorPalette());
}, "deleteCustomColorPalette");
const loadColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
colorPalette = await completeColorPalette(colorPalette);
if (colorPalette.colors) {
if (colorPalette.colors.node_slot) {
Object.assign(
app.canvas.default_connection_color_byType,
colorPalette.colors.node_slot
);
Object.assign(
LGraphCanvas.link_type_colors,
colorPalette.colors.node_slot
);
}
if (colorPalette.colors.litegraph_base) {
app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR;
app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR;
for (const key in colorPalette.colors.litegraph_base) {
if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) {
LiteGraph[key] = colorPalette.colors.litegraph_base[key];
}
}
}
if (colorPalette.colors.comfy_base) {
const rootStyle = document.documentElement.style;
for (const key in colorPalette.colors.comfy_base) {
rootStyle.setProperty(
"--" + key,
colorPalette.colors.comfy_base[key]
);
}
}
if (colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR) {
app.bypassBgColor = colorPalette.colors.litegraph_base.NODE_BYPASS_BGCOLOR;
}
app.canvas.draw(true, true);
}
}, "loadColorPalette");
const fileInput = $el("input", {
type: "file",
accept: ".json",
style: { display: "none" },
parent: document.body,
onchange: /* @__PURE__ */ __name(() => {
const file = fileInput.files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
await addCustomColorPalette(JSON.parse(reader.result));
};
reader.readAsText(file);
}
}, "onchange")
});
app.ui.settings.addSetting({
id,
category: ["Comfy", "ColorPalette"],
name: "Color Palette",
type: /* @__PURE__ */ __name((name, setter, value) => {
const options = [
...Object.values(colorPalettes).map(
(c) => $el("option", {
textContent: c.name,
value: c.id,
selected: c.id === value
})
),
...Object.values(getCustomColorPalettes()).map(
(c) => $el("option", {
textContent: `${c.name} (custom)`,
value: `custom_${c.id}`,
selected: `custom_${c.id}` === value
})
)
];
els.select = $el(
"select",
{
style: {
marginBottom: "0.15rem",
width: "100%"
},
onchange: /* @__PURE__ */ __name((e) => {
setter(e.target.value);
}, "onchange")
},
options
);
return $el("tr", [
$el("td", [
els.select,
$el(
"div",
{
style: {
display: "grid",
gap: "4px",
gridAutoFlow: "column"
}
},
[
$el("input", {
type: "button",
value: "Export",
onclick: /* @__PURE__ */ __name(async () => {
const colorPaletteId = app.ui.settings.getSettingValue(
id,
defaultColorPaletteId
);
const colorPalette = await completeColorPalette(
getColorPalette(colorPaletteId)
);
const json = JSON.stringify(colorPalette, null, 2);
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: colorPaletteId + ".json",
style: { display: "none" },
parent: document.body
});
a.click();
setTimeout(function() {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}, "onclick")
}),
$el("input", {
type: "button",
value: "Import",
onclick: /* @__PURE__ */ __name(() => {
fileInput.click();
}, "onclick")
}),
$el("input", {
type: "button",
value: "Template",
onclick: /* @__PURE__ */ __name(async () => {
const colorPalette = await getColorPaletteTemplate();
const json = JSON.stringify(colorPalette, null, 2);
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "color_palette.json",
style: { display: "none" },
parent: document.body
});
a.click();
setTimeout(function() {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}, "onclick")
}),
$el("input", {
type: "button",
value: "Delete",
onclick: /* @__PURE__ */ __name(async () => {
let colorPaletteId = app.ui.settings.getSettingValue(
id,
defaultColorPaletteId
);
if (colorPalettes[colorPaletteId]) {
useToastStore().addAlert(
"You cannot delete a built-in color palette."
);
return;
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
}
await deleteCustomColorPalette(colorPaletteId);
}, "onclick")
})
]
)
])
]);
}, "type"),
defaultValue: defaultColorPaletteId,
async onChange(value) {
if (!value) {
return;
}
let palette = colorPalettes[value];
if (palette) {
await loadColorPalette(palette);
} else if (value.startsWith("custom_")) {
value = value.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[value]) {
palette = customColorPalettes[value];
await loadColorPalette(customColorPalettes[value]);
}
}
let { BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR } = palette.colors.litegraph_base;
if (BACKGROUND_IMAGE === void 0 || CLEAR_BACKGROUND_COLOR === void 0) {
const base = colorPalettes["dark"].colors.litegraph_base;
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
}
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
}
});
}
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
export {
defaultColorPalette as d,
getColorPalette as g
};
//# sourceMappingURL=colorPalette-D5oi2-2V.js.map

1
web/assets/colorPalette-D5oi2-2V.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/index-BD-Ia1C4.js.map generated vendored

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1
web/assets/index-BMC1ey-i.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

1
web/assets/index-CI3N807S.js.map generated vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
web/assets/index-DGAbdBYF.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

2602
web/assets/sorted-custom-node-map.json generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,37 @@
.lds-ring {
display: inline-block;
position: relative;
width: 1em;
height: 1em;
}
.lds-ring div {
box-sizing: border-box;
display: block;
position: absolute;
width: 100%;
height: 100%;
border: 0.15em solid #fff;
border-radius: 50%;
animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite;
border-color: #fff transparent transparent transparent;
}
.lds-ring div:nth-child(1) {
animation-delay: -0.45s;
}
.lds-ring div:nth-child(2) {
animation-delay: -0.3s;
}
.lds-ring div:nth-child(3) {
animation-delay: -0.15s;
}
@keyframes lds-ring {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
.comfy-user-selection { .comfy-user-selection {
width: 100vw; width: 100vw;
height: 100vh; height: 100vh;

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,15 @@
var __defProp = Object.defineProperty; var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true }); var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js"; import { b4 as api, ca as $el } from "./index-DGAbdBYF.js";
function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
return div.firstElementChild;
}
__name(createSpinner, "createSpinner");
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.spinner = window.comfyAPI.spinner || {};
window.comfyAPI.spinner.createSpinner = createSpinner;
class UserSelectionScreen { class UserSelectionScreen {
static { static {
__name(this, "UserSelectionScreen"); __name(this, "UserSelectionScreen");
@@ -117,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export { export {
UserSelectionScreen UserSelectionScreen
}; };
//# sourceMappingURL=userSelection-CyXKCVy3.js.map //# sourceMappingURL=userSelection-Duxc-t_S.js.map

1
web/assets/userSelection-Duxc-t_S.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

756
web/assets/widgetInputs-DdoWwzg5.js generated vendored Normal file
View File

@@ -0,0 +1,756 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { l as LGraphNode, k as app, cf as applyTextReplacements, ce as ComfyWidgets, ci as addValueControlWidgets, z as LiteGraph } from "./index-DGAbdBYF.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = [
"STRING",
"combo",
"number",
"toggle",
"BOOLEAN",
"text",
"string"
];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
const TARGET = Symbol();
const replacePropertyName = "Run widget replace on values";
class PrimitiveNode extends LGraphNode {
static {
__name(this, "PrimitiveNode");
}
controlValues;
lastType;
static category;
constructor(title) {
super(title);
this.addOutput("connect to widget input", "*");
this.serialize_widgets = true;
this.isVirtualNode = true;
if (!this.properties || !(replacePropertyName in this.properties)) {
this.addProperty(replacePropertyName, false, "boolean");
}
}
applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
let links2 = [];
for (const l of node.outputs[0].links) {
const linkInfo = app.graph.links[l];
const n = node.graph.getNodeById(linkInfo.target_id);
if (n.type == "Reroute") {
links2 = links2.concat(get_links(n));
} else {
links2.push(l);
}
}
return links2;
}
__name(get_links, "get_links");
let links = [
...get_links(this).map((l) => app.graph.links[l]),
...extraLinks
];
let v = this.widgets?.[0].value;
if (v && this.properties[replacePropertyName]) {
v = applyTextReplacements(app, v);
}
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
let widget;
if (input.widget[TARGET]) {
widget = input.widget[TARGET];
} else {
const widgetName = input.widget.name;
if (widgetName) {
widget = node.widgets.find((w) => w.name === widgetName);
}
}
if (widget) {
widget.value = v;
if (widget.callback) {
widget.callback(
widget.value,
app.canvas,
node,
app.canvas.graph_mouse,
{}
);
}
}
}
}
refreshComboInNode() {
const widget = this.widgets?.[0];
if (widget?.type === "combo") {
widget.options.values = this.outputs[0].widget[GET_CONFIG]()[0];
if (!widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}
onAfterGraphConfigured() {
if (this.outputs[0].links?.length && !this.widgets?.length) {
if (!this.#onFirstConnection()) return;
if (this.widgets) {
for (let i = 0; i < this.widgets_values.length; i++) {
const w = this.widgets[i];
if (w) {
w.value = this.widgets_values[i];
}
}
}
this.#mergeWidgetConfig();
}
}
onConnectionsChange(_, index, connected) {
if (app.configuringGraph) {
return;
}
const links = this.outputs[0].links;
if (connected) {
if (links?.length && !this.widgets?.length) {
this.#onFirstConnection();
}
} else {
this.#mergeWidgetConfig();
if (!links?.length) {
this.onLastDisconnect();
}
}
}
onConnectOutput(slot, type, input, target_node, target_slot) {
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false;
}
if (this.outputs[slot].links?.length) {
const valid = this.#isValidConnection(input);
if (valid) {
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
}
return valid;
}
}
#onFirstConnection(recreating) {
if (!this.outputs[0].links) {
this.onLastDisconnect();
return;
}
const linkId = this.outputs[0].links[0];
const link = this.graph.links[linkId];
if (!link) return;
const theirNode = this.graph.getNodeById(link.target_id);
if (!theirNode || !theirNode.inputs) return;
const input = theirNode.inputs[link.target_slot];
if (!input) return;
let widget;
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return;
widget = { name: input.name, [GET_CONFIG]: () => [input.type, {}] };
} else {
widget = input.widget;
}
const config = widget[GET_CONFIG]?.();
if (!config) return;
const { type } = getWidgetType(config);
this.outputs[0].type = type;
this.outputs[0].name = type;
this.outputs[0].widget = widget;
this.#createWidget(
widget[CONFIG] ?? config,
theirNode,
widget.name,
recreating,
widget[TARGET]
);
}
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
let type = inputData[0];
if (type instanceof Array) {
type = "COMBO";
}
const size = this.size;
let widget;
if (type in ComfyWidgets) {
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
} else {
widget = this.addWidget(type, "value", null, () => {
}, {});
}
if (targetWidget) {
widget.value = targetWidget.value;
} else if (node?.widgets && widget) {
const theirWidget = node.widgets.find((w) => w.name === widgetName);
if (theirWidget) {
widget.value = theirWidget.value;
}
}
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidgets(
this,
widget,
control_value,
void 0,
inputData
);
let filter = this.widgets_values?.[2];
if (filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
const controlValues = this.controlValues;
if (this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
for (let i = 0; i < controlValues.length; i++) {
this.widgets[i + 1].value = controlValues[i];
}
}
const callback = widget.callback;
const self = this;
widget.callback = function() {
const r = callback ? callback.apply(this, arguments) : void 0;
self.applyToGraph();
return r;
};
this.size = [
Math.max(this.size[0], size[0]),
Math.max(this.size[1], size[1])
];
if (!recreating) {
const sz = this.computeSize();
if (this.size[0] < sz[0]) {
this.size[0] = sz[0];
}
if (this.size[1] < sz[1]) {
this.size[1] = sz[1];
}
requestAnimationFrame(() => {
if (this.onResize) {
this.onResize(this.size);
}
});
}
}
recreateWidget() {
const values = this.widgets?.map((w) => w.value);
this.#removeWidgets();
this.#onFirstConnection(true);
if (values?.length) {
for (let i = 0; i < this.widgets?.length; i++)
this.widgets[i].value = values[i];
}
return this.widgets?.[0];
}
#mergeWidgetConfig() {
const output = this.outputs[0];
const links = output.links;
const hasConfig = !!output.widget[CONFIG];
if (hasConfig) {
delete output.widget[CONFIG];
}
if (links?.length < 2 && hasConfig) {
if (links.length) {
this.recreateWidget();
}
return;
}
const config1 = output.widget[GET_CONFIG]();
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
if (!isNumber) return;
for (const linkId of links) {
const link = app.graph.links[linkId];
if (!link) continue;
const theirNode = app.graph.getNodeById(link.target_id);
const theirInput = theirNode.inputs[link.target_slot];
this.#isValidConnection(theirInput, hasConfig);
}
}
isValidWidgetLink(originSlot, targetNode, targetWidget) {
const config2 = getConfig.call(targetNode, targetWidget.name) ?? [
targetWidget.type,
targetWidget.options || {}
];
if (!isConvertibleWidget(targetWidget, config2)) return false;
const output = this.outputs[originSlot];
if (!(output.widget?.[CONFIG] ?? output.widget?.[GET_CONFIG]())) {
return true;
}
return !!mergeIfValid.call(this, output, config2);
}
#isValidConnection(input, forceUpdate) {
const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid.call(
this,
output,
config2,
forceUpdate,
this.recreateWidget
);
}
#removeWidgets() {
if (this.widgets) {
for (const w of this.widgets) {
if (w.onRemove) {
w.onRemove();
}
}
this.controlValues = [];
this.lastType = this.widgets[0]?.type;
for (let i = 1; i < this.widgets.length; i++) {
this.controlValues.push(this.widgets[i].value);
}
setTimeout(() => {
delete this.lastType;
delete this.controlValues;
}, 15);
this.widgets.length = 0;
}
}
onLastDisconnect() {
this.outputs[0].type = "*";
this.outputs[0].name = "connect to widget input";
delete this.outputs[0].widget;
this.#removeWidgets();
}
}
function getWidgetConfig(slot) {
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
}
__name(getWidgetConfig, "getWidgetConfig");
function getConfig(widgetName) {
const { nodeData } = this.constructor;
return nodeData?.input?.required?.[widgetName] ?? nodeData?.input?.optional?.[widgetName];
}
__name(getConfig, "getConfig");
function isConvertibleWidget(widget, config) {
return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput;
}
__name(isConvertibleWidget, "isConvertibleWidget");
function hideWidget(node, widget, suffix = "") {
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
widget.computeSize = () => [0, -4];
widget.type = CONVERTED_TYPE + suffix;
widget.serializeValue = () => {
if (!node.inputs) {
return void 0;
}
let node_input = node.inputs.find((i) => i.widget?.name === widget.name);
if (!node_input || !node_input.link) {
return void 0;
}
return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
};
if (widget.linkedWidgets) {
for (const w of widget.linkedWidgets) {
hideWidget(node, w, ":" + widget.name);
}
}
}
__name(hideWidget, "hideWidget");
function showWidget(widget) {
widget.type = widget.origType;
widget.computeSize = widget.origComputeSize;
widget.serializeValue = widget.origSerializeValue;
delete widget.origType;
delete widget.origComputeSize;
delete widget.origSerializeValue;
if (widget.linkedWidgets) {
for (const w of widget.linkedWidgets) {
showWidget(w);
}
}
}
__name(showWidget, "showWidget");
function convertToInput(node, widget, config) {
hideWidget(node, widget);
const { type } = getWidgetType(config);
const sz = node.size;
const inputIsOptional = !!widget.options?.inputIsOptional;
const input = node.addInput(widget.name, type, {
widget: { name: widget.name, [GET_CONFIG]: () => config },
...inputIsOptional ? { shape: LiteGraph.SlotShape.HollowCircle } : {}
});
for (const widget2 of node.widgets) {
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
}
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
return input;
}
__name(convertToInput, "convertToInput");
function convertToWidget(node, widget) {
showWidget(widget);
const sz = node.size;
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
for (const widget2 of node.widgets) {
widget2.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
}
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
}
__name(convertToWidget, "convertToWidget");
function getWidgetType(config) {
let type = config[0];
if (type instanceof Array) {
type = "COMBO";
}
return { type };
}
__name(getWidgetType, "getWidgetType");
function isValidCombo(combo, obj) {
if (!(obj instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${obj}`);
return false;
}
if (combo.length !== obj.length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
if (combo.find((v, i) => obj[i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
return true;
}
__name(isValidCombo, "isValidCombo");
function isPrimitiveNode(node) {
return node.type === "PrimitiveNode";
}
__name(isPrimitiveNode, "isPrimitiveNode");
function setWidgetConfig(slot, config, target) {
if (!slot.widget) return;
if (config) {
slot.widget[GET_CONFIG] = () => config;
slot.widget[TARGET] = target;
} else {
delete slot.widget;
}
if (slot.link) {
const link = app.graph.links[slot.link];
if (link) {
const originNode = app.graph.getNodeById(link.origin_id);
if (isPrimitiveNode(originNode)) {
if (config) {
originNode.recreateWidget();
} else if (!app.configuringGraph) {
originNode.disconnectOutput(0);
originNode.onLastDisconnect();
}
}
}
}
}
__name(setWidgetConfig, "setWidgetConfig");
function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
}
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return;
} else if (config1[0] !== config2[0]) {
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return;
}
const keys = /* @__PURE__ */ new Set([
...Object.keys(config1[1] ?? {}),
...Object.keys(config2[1] ?? {})
]);
let customConfig;
const getCustomConfig = /* @__PURE__ */ __name(() => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
}, "getCustomConfig");
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline" && k !== "tooltip") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || !v1 && !v2) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
step = v2;
} else if (v2 == null) {
step = v1;
} else {
if (v1 < v2) {
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log(
"connection rejected: steps not divisible",
"current:",
v1,
"new:",
v2
);
return;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.call(this);
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return { customConfig };
}
__name(mergeIfValid, "mergeIfValid");
let useConversionSubmenusSetting;
app.registerExtension({
name: "Comfy.WidgetInputs",
init() {
useConversionSubmenusSetting = app.ui.settings.addSetting({
id: "Comfy.NodeInputConversionSubmenus",
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
type: "boolean",
defaultValue: true
});
},
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
nodeType.prototype.convertWidgetToInput = function(widget) {
const config = getConfig.call(this, widget.name) ?? [
widget.type,
widget.options || {}
];
if (!isConvertibleWidget(widget, config)) return false;
if (widget.type?.startsWith(CONVERTED_TYPE)) return false;
convertToInput(this, widget, config);
return true;
};
nodeType.prototype.getExtraMenuOptions = function(_, options) {
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
if (this.widgets) {
let toInput = [];
let toWidget = [];
for (const w of this.widgets) {
if (w.options?.forceInput) {
continue;
}
if (w.type === CONVERTED_TYPE) {
toWidget.push({
content: `Convert ${w.name} to widget`,
callback: /* @__PURE__ */ __name(() => convertToWidget(this, w), "callback")
});
} else {
const config = getConfig.call(this, w.name) ?? [
w.type,
w.options || {}
];
if (isConvertibleWidget(w, config)) {
toInput.push({
content: `Convert ${w.name} to input`,
callback: /* @__PURE__ */ __name(() => convertToInput(this, w, config), "callback")
});
}
}
}
if (toInput.length) {
if (useConversionSubmenusSetting.value) {
options.push({
content: "Convert Widget to Input",
submenu: {
options: toInput
}
});
} else {
options.push(...toInput, null);
}
}
if (toWidget.length) {
if (useConversionSubmenusSetting.value) {
options.push({
content: "Convert Input to Widget",
submenu: {
options: toWidget
}
});
} else {
options.push(...toWidget, null);
}
}
}
return r;
};
nodeType.prototype.onGraphConfigured = function() {
if (!this.inputs) return;
this.widgets ??= [];
for (const input of this.inputs) {
if (input.widget) {
if (!input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
}
if (input.widget.config) {
if (input.widget.config[0] instanceof Array) {
input.type = "COMBO";
const link = app2.graph.links[input.link];
if (link) {
link.type = input.type;
}
}
delete input.widget.config;
}
const w = this.widgets.find((w2) => w2.name === input.widget.name);
if (w) {
hideWidget(this, w);
} else {
convertToWidget(this, input);
}
}
}
};
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function() {
const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : void 0;
if (!app2.configuringGraph && this.widgets) {
for (const w of this.widgets) {
if (w?.options?.forceInput || w?.options?.defaultInput) {
const config = getConfig.call(this, w.name) ?? [
w.type,
w.options || {}
];
convertToInput(this, w, config);
}
}
}
return r;
};
const origOnConfigure = nodeType.prototype.onConfigure;
nodeType.prototype.onConfigure = function() {
const r = origOnConfigure ? origOnConfigure.apply(this, arguments) : void 0;
if (!app2.configuringGraph && this.inputs) {
for (const input of this.inputs) {
if (input.widget && !input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
const w = this.widgets.find((w2) => w2.name === input.widget.name);
if (w) {
hideWidget(this, w);
}
}
}
}
return r;
};
function isNodeAtPos(pos) {
for (const n of app2.graph.nodes) {
if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
return true;
}
}
return false;
}
__name(isNodeAtPos, "isNodeAtPos");
const origOnInputDblClick = nodeType.prototype.onInputDblClick;
const ignoreDblClick = Symbol();
nodeType.prototype.onInputDblClick = function(slot) {
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : void 0;
const input = this.inputs[slot];
if (!input.widget || !input[ignoreDblClick]) {
if (!(input.type in ComfyWidgets) && !(input.widget[GET_CONFIG]?.()?.[0] instanceof Array)) {
return r;
}
}
const node = LiteGraph.createNode("PrimitiveNode");
app2.graph.add(node);
const pos = [
this.pos[0] - node.size[0] - 30,
this.pos[1]
];
while (isNodeAtPos(pos)) {
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
}
node.pos = pos;
node.connect(0, this, slot);
node.title = input.name;
input[ignoreDblClick] = true;
setTimeout(() => {
delete input[ignoreDblClick];
}, 300);
return r;
};
const onConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function(targetSlot, type, output, originNode, originSlot) {
const v = onConnectInput?.(this, arguments);
if (type !== "COMBO") return v;
if (originNode.outputs[originSlot].widget) return v;
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
if (!targetCombo || !(targetCombo instanceof Array)) return v;
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
return false;
}
return v;
};
},
registerCustomNodes() {
LiteGraph.registerNodeType(
"PrimitiveNode",
Object.assign(PrimitiveNode, {
title: "Primitive"
})
);
PrimitiveNode.category = "utils";
}
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.widgetInputs = window.comfyAPI.widgetInputs || {};
window.comfyAPI.widgetInputs.getWidgetConfig = getWidgetConfig;
window.comfyAPI.widgetInputs.convertToInput = convertToInput;
window.comfyAPI.widgetInputs.setWidgetConfig = setWidgetConfig;
window.comfyAPI.widgetInputs.mergeIfValid = mergeIfValid;
export {
convertToInput,
getWidgetConfig,
mergeIfValid,
setWidgetConfig
};
//# sourceMappingURL=widgetInputs-DdoWwzg5.js.map

1
web/assets/widgetInputs-DdoWwzg5.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,2 +1,2 @@
// Shim for extensions\core\clipspace.ts // Shim for extensions/core/clipspace.ts
export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog; export const ClipspaceDialog = window.comfyAPI.clipspace.ClipspaceDialog;

3
web/extensions/core/colorPalette.js vendored Normal file
View File

@@ -0,0 +1,3 @@
// Shim for extensions/core/colorPalette.ts
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;

View File

@@ -1,3 +1,3 @@
// Shim for extensions\core\groupNode.ts // Shim for extensions/core/groupNode.ts
export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig; export const GroupNodeConfig = window.comfyAPI.groupNode.GroupNodeConfig;
export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler; export const GroupNodeHandler = window.comfyAPI.groupNode.GroupNodeHandler;

View File

@@ -1,2 +1,2 @@
// Shim for extensions\core\groupNodeManage.ts // Shim for extensions/core/groupNodeManage.ts
export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog; export const ManageGroupDialog = window.comfyAPI.groupNodeManage.ManageGroupDialog;

View File

@@ -1,4 +1,5 @@
// Shim for extensions\core\widgetInputs.ts // Shim for extensions/core/widgetInputs.ts
export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig; export const getWidgetConfig = window.comfyAPI.widgetInputs.getWidgetConfig;
export const convertToInput = window.comfyAPI.widgetInputs.convertToInput;
export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig; export const setWidgetConfig = window.comfyAPI.widgetInputs.setWidgetConfig;
export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid; export const mergeIfValid = window.comfyAPI.widgetInputs.mergeIfValid;

92
web/index.html vendored
View File

@@ -1,50 +1,42 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<title>ComfyUI</title> <title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no"> <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<!-- Browser Test Fonts --> <link rel="stylesheet" type="text/css" href="user.css" />
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet"> <link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet"> <script type="module" crossorigin src="./assets/index-DGAbdBYF.js"></script>
<style> <link rel="stylesheet" crossorigin href="./assets/index-BHJGjcJh.css">
* { </head>
font-family: 'Roboto Mono', 'Noto Color Emoji'; <body class="litegraph grid">
} <div id="vue-app"></div>
</style> --> <div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<link rel="stylesheet" type="text/css" href="user.css" /> <section>
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" /> <label>New user:
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script> <input placeholder="Enter a username" />
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css"> </label>
</head> </section>
<body class="litegraph"> <div class="comfy-user-existing">
<div id="vue-app"></div> <span class="or-separator">OR</span>
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;"> <section>
<main class="comfy-user-selection-inner"> <label>
<h1>ComfyUI</h1> Existing user:
<form> <select>
<section> <option hidden disabled selected value> Select a user </option>
<label>New user: </select>
<input placeholder="Enter a username" /> </label>
</label> </section>
</section> </div>
<div class="comfy-user-existing"> <footer>
<span class="or-separator">OR</span> <span class="comfy-user-error">&nbsp;</span>
<section> <button class="comfy-btn comfy-user-button-next">Next</button>
<label> </footer>
Existing user: </form>
<select> </main>
<option hidden disabled selected value> Select a user </option> </div>
</select> </body>
</label> </html>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>

Some files were not shown because too many files have changed in this diff Show More