Compare commits

...

271 Commits

Author SHA1 Message Date
Chenlei Hu
65a8659182 Update web content to release v1.3.26 (#5413)
* Update web content to release v1.3.26

* nit
2024-10-29 14:14:06 -04:00
comfyanonymous
770ab200f2 Cleanup SkipLayerGuidanceSD3 node. 2024-10-29 10:11:46 -04:00
Dango233
954683d0db SLG first implementation for SD3.5 (#5404)
* SLG first implementation for SD3.5

* * Simplify and align with comfy style
2024-10-29 09:59:21 -04:00
comfyanonymous
30c0c81351 Add a way to patch blocks in SD3. 2024-10-29 00:48:32 -04:00
comfyanonymous
13b0ff8a6f Update SD3 code. 2024-10-28 21:58:52 -04:00
comfyanonymous
c320801187 Remove useless line. 2024-10-28 17:41:12 -04:00
Chenlei Hu
c0b0cfaeec Update web content to release v1.3.21 (#5351)
* Update web content to release v1.3.21

* nit
2024-10-28 14:29:38 -04:00
comfyanonymous
669d9e4c67 Set default shift on mochi to 6.0 2024-10-27 22:21:04 -04:00
comfyanonymous
9ee0a6553a float16 inference is a bit broken on mochi. 2024-10-27 04:56:40 -04:00
comfyanonymous
5cbb01bc2f Basic Genmo Mochi video model support.
To use:
"Load CLIP" node with t5xxl + type mochi
"Load Diffusion Model" node with the mochi dit file.
"Load VAE" with the mochi vae file.

EmptyMochiLatentVideo node for the latent.
euler + linear_quadratic in the KSampler node.
2024-10-26 06:54:00 -04:00
comfyanonymous
c3ffbae067 Make LatentUpscale nodes work on 3d latents. 2024-10-26 01:50:51 -04:00
comfyanonymous
d605677b33 Make euler_ancestral work on flow models (credit: Ashen). 2024-10-25 19:53:44 -04:00
Chenlei Hu
ce759b7db6 Revert download to .tmp in frontend_management (#5369) 2024-10-25 19:26:13 -04:00
comfyanonymous
52810907e2 Add a model merge node for SD3.5 large. 2024-10-24 16:46:21 -04:00
PsychoLogicAu
af8cf79a2d support SimpleTuner lycoris lora for SD3 (#5340) 2024-10-24 01:18:32 -04:00
comfyanonymous
66b0961a46 Fix ControlLora issue with last commit. 2024-10-23 17:02:40 -04:00
comfyanonymous
754597c8a9 Clean up some controlnet code.
Remove self.device which was useless.
2024-10-23 14:19:05 -04:00
comfyanonymous
915fdb5745 Fix lowvram edge case. 2024-10-22 16:34:50 -04:00
contentis
5a8a48931a remove attention abstraction (#5324) 2024-10-22 14:02:38 -04:00
comfyanonymous
8ce2a1052c Optimizations to --fast and scaled fp8. 2024-10-22 02:12:28 -04:00
comfyanonymous
f82314fcfc Fix duplicate sigmas on beta scheduler. 2024-10-21 20:19:45 -04:00
comfyanonymous
0075c6d096 Mixed precision diffusion models with scaled fp8.
This change allows supports for diffusion models where all the linears are
scaled fp8 while the other weights are the original precision.
2024-10-21 18:12:51 -04:00
comfyanonymous
83ca891118 Support scaled fp8 t5xxl model. 2024-10-20 22:27:00 -04:00
comfyanonymous
f9f9faface Fixed model merging issue with scaled fp8. 2024-10-20 06:24:31 -04:00
comfyanonymous
471cd3eace fp8 casting is fast on GPUs that support fp8 compute. 2024-10-20 00:54:47 -04:00
comfyanonymous
a68bbafddb Support diffusion models with scaled fp8 weights. 2024-10-19 23:47:42 -04:00
comfyanonymous
73e3a9e676 Clamp output when rounding weight to prevent Nan. 2024-10-19 19:07:10 -04:00
comfyanonymous
518c0dc2fe Add tooltips to LoraSave node. 2024-10-18 06:01:09 -04:00
comfyanonymous
ce0542e10b Add a note that python 3.13 is not yet supported to the README. 2024-10-17 19:27:37 -04:00
comfyanonymous
8473019d40 Pytorch can be shipped with numpy 2 now. 2024-10-17 19:15:17 -04:00
Xiaodong Xie
89f15894dd Ignore more network related errors during websocket communication. (#5269)
Intermittent network issues during websocket communication should not crash ComfyUi process.

Co-authored-by: Xiaodong Xie <xie.xiaodong@frever.com>
2024-10-17 18:31:45 -04:00
comfyanonymous
67158994a4 Use the lowvram cast_to function for everything. 2024-10-17 17:25:56 -04:00
comfyanonymous
7390ff3b1e Add missing import. 2024-10-16 14:58:30 -04:00
comfyanonymous
0bedfb26af Revert "Fix Transformers FutureWarning (#5140)"
This reverts commit 95b7cf9bbe.
2024-10-16 12:36:19 -04:00
comfyanonymous
f71cfd2687 Add an experimental node to sharpen latents.
Can be used with LatentApplyOperationCFG for interesting results.
2024-10-16 05:25:31 -04:00
Alex "mcmonkey" Goodwin
c695c4af7f Frontend Manager: avoid redundant gh calls for static versions (#5152)
* Frontend Manager: avoid redundant gh calls for static versions

* actually, removing old tmpdir isn't needed

I tested - downloader code handles this case well already
(also rmdir was wrong func anyway, needed shutil.rmtree if it had content)

* add code comment
2024-10-16 03:35:37 -04:00
comfyanonymous
0dbba9f751 Add some latent operation nodes.
This is a port of the ModelSamplerTonemapNoiseTest from the experiments
repo.

To replicate that node use LatentOperationTonemapReinhard and
LatentApplyOperationCFG together.
2024-10-15 15:00:36 -04:00
comfyanonymous
f584758271 Cleanup some useless lines. 2024-10-14 21:02:39 -04:00
svdc
95b7cf9bbe Fix Transformers FutureWarning (#5140)
* Update sd1_clip.py

Fix Transformers FutureWarning

* Update sd1_clip.py

Fix comment
2024-10-14 20:12:20 -04:00
comfyanonymous
191a0d56b4 Switch default packaging workflows to python 3.12 2024-10-13 06:59:31 -04:00
comfyanonymous
3c60ecd7a8 Fix fp8 ops staying enabled. 2024-10-12 14:10:13 -04:00
comfyanonymous
7ae6626723 Remove useless argument. 2024-10-12 07:16:21 -04:00
comfyanonymous
6632365e16 model_options consistency between functions.
weight_dtype -> dtype
2024-10-11 20:51:19 -04:00
Kadir Nar
ad07796777 🐛 Add device to variable c (#5210) 2024-10-11 20:37:50 -04:00
comfyanonymous
1b80895285 Make clip loader nodes support loading sd3 t5xxl in lower precision.
Add attention mask support in the SD3 text encoder code.
2024-10-10 15:06:15 -04:00
Dr.Lt.Data
5f9d5a244b Hotfix for the div zero occurrence when memory_used_encode is 0 (#5121)
https://github.com/comfyanonymous/ComfyUI/issues/5069#issuecomment-2382656368
2024-10-09 23:34:34 -04:00
Chenlei Hu
14eba07acd Update web content to release v1.3.11 (#5189)
* Update web content to release v1.3.11

* nit
2024-10-09 22:37:04 -04:00
Jonathan Avila
4b2f0d9413 Increase maximum macOS version to 15.0.1 when forcing upcast attention (#5191) 2024-10-09 22:21:41 -04:00
Yoland Yan
25eac1d780 Change runner label for the new runners (#5197) 2024-10-09 20:08:57 -04:00
comfyanonymous
e38c94228b Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
2024-10-09 19:43:17 -04:00
comfyanonymous
203942c8b2 Fix flux doras with diffusers keys. 2024-10-08 19:03:40 -04:00
Brendan Hoar
3c72c89a52 Update folder_paths.py - try/catch for special file_name values (#5187)
Somehow managed to drop a file called "nul" into a windows checkpoints subdirectory. This caused all sorts of havoc with many nodes that needed the list of checkpoints.
2024-10-08 15:04:32 -04:00
Chenlei Hu
614377abd6 Update web content to release v1.2.64 (#5124) 2024-10-07 17:15:29 -04:00
comfyanonymous
8dfa0cc552 Make SD3 fast previews a little better. 2024-10-07 09:19:59 -04:00
comfyanonymous
e5ecdfdd2d Make fast previews for SDXL a little better by adding a bias. 2024-10-06 19:27:04 -04:00
comfyanonymous
7d29fbf74b Slightly improve the fast previews for flux by adding a bias. 2024-10-06 17:55:46 -04:00
Lex
2c641e64ad IS_CHANGED should be a classmethod (#5159) 2024-10-06 05:47:51 -04:00
comfyanonymous
7d2467e830 Some minor cleanups. 2024-10-05 13:22:39 -04:00
comfyanonymous
6f021d8aa0 Let --verbose have an argument for the log level. 2024-10-04 10:05:34 -04:00
comfyanonymous
d854ed0bcf Allow using SD3 type te output on flux model. 2024-10-03 09:44:54 -04:00
comfyanonymous
abcd006b8c Allow more permutations of clip/t5 in dual clip loader. 2024-10-03 09:26:11 -04:00
comfyanonymous
d985d1d7dc CLIP Loader node now supports clip_l and clip_g only for SD3. 2024-10-02 04:25:17 -04:00
comfyanonymous
d1cdf51e1b Refactor some of the TE detection code. 2024-10-01 07:08:41 -04:00
comfyanonymous
b4626ab93e Add simpletuner lycoris format for SD unet. 2024-09-30 06:03:27 -04:00
comfyanonymous
a9e459c2a4 Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
2024-09-29 11:27:49 -04:00
comfyanonymous
3bb4dec720 Fix issue with loras, lowvram and --fast fp8. 2024-09-28 14:42:32 -04:00
City
8733191563 Flux torch.compile fix (#5082) 2024-09-27 22:07:51 -04:00
comfyanonymous
83b01f960a Add backend option to TorchCompileModel.
If you want to use the cudagraphs backend you need to: --disable-cuda-malloc

If you get other backends working feel free to make a PR to add them.
2024-09-27 02:12:37 -04:00
comfyanonymous
d72e871cfa Add a note that the experimental model downloader api will be removed. 2024-09-26 03:17:52 -04:00
comfyanonymous
037c3159b6 Move some nodes out of _for_testing. 2024-09-25 08:41:22 -04:00
comfyanonymous
bdd4a22a2e Fix flux TE not loading t5 embeddings. 2024-09-24 22:57:22 -04:00
comfyanonymous
fdf37566ef Add batch size to EmptyLatentAudio. 2024-09-24 04:32:55 -04:00
Alex "mcmonkey" Goodwin
08c8968482 Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
2024-09-24 03:50:45 -04:00
chaObserv
479a427a48 Add dpmpp_2m_cfg_pp (#4992) 2024-09-24 02:42:56 -04:00
comfyanonymous
3a0eeee320 Make --listen listen on both ipv4 and ipv6 at the same time by default. 2024-09-23 04:38:19 -04:00
comfyanonymous
447da7ea86 Support listening on multiple addresses. 2024-09-23 04:36:59 -04:00
comfyanonymous
9c41bc8d10 Remove useless line. 2024-09-23 02:32:29 -04:00
Robin Huang
6ad0ddbae4 Run unit tests on Windows/MacOS as well. (#5018)
* Run unit tests on Windows as well.

* Test on mac.

* Continue running on error.

* Compared normalized paths to work cross platform.

* Only test common set of mimetypes across operating systems.
2024-09-22 05:01:39 -04:00
RandomGitUser321
a55142f904 Add ws.close() to the websocket examples (#5020)
* add ws.close() to websocket examples

* add and explain ws.close() in websocket examples
2024-09-22 04:59:10 -04:00
comfyanonymous
5718ef69bb Add total and free ram to /system_stats. 2024-09-22 03:42:11 -04:00
RandomGitUser321
13ecf10a92 Added to the websockets_api_example.py to show how to decode latent previews from the binary stream (#5016)
* Update websockets_api_example.py

* even more simplfied
2024-09-22 02:30:44 -04:00
comfyanonymous
7a415f47a9 Add an optional VAE input to the ControlNetApplyAdvanced node.
Deprecate the other controlnet nodes.
2024-09-22 01:24:52 -04:00
Chenlei Hu
89fa2fca24 Update web content to release v1.2.60 (#5017)
* Update web content to release v1.2.60

* Remove dist.zip
2024-09-21 23:28:54 -04:00
comfyanonymous
364b69e931 Make SD3 empty latent image zeros.
This shouldn't change anything. The reason it was not zeros is because it
did matter in early versions of the code.
2024-09-21 09:13:10 -04:00
comfyanonymous
dc96a1ae19 Load controlnet in fp8 if weights are in fp8. 2024-09-21 04:50:12 -04:00
comfyanonymous
2d810b081e Add load_controlnet_state_dict function. 2024-09-21 01:51:51 -04:00
comfyanonymous
9f7e9f0547 Add an error message when a controlnet needs a VAE but none is given. 2024-09-21 01:33:18 -04:00
comfyanonymous
a355f38ecc Make the SD3 controlnet node the default one. 2024-09-21 01:32:46 -04:00
huchenlei
38c69080c7 Add docstring 2024-09-20 03:16:23 -04:00
comfyanonymous
70a708d726 Fix model merging issue. 2024-09-20 02:31:44 -04:00
yoinked
e7d4782736 add laplace scheduler [2407.03297] (#4990)
* add laplace scheduler [2407.03297]

* should be here instead lol

* better settings
2024-09-19 23:23:09 -04:00
Alex "mcmonkey" Goodwin
3326bdfd4e add internal /folder_paths route (#4980)
returns a json maps of folder paths
2024-09-19 09:52:55 -04:00
Alex "mcmonkey" Goodwin
68bb885d22 add 'is_default' to model paths config (#4979)
* add 'is_default' to model paths config

including impl and doc in example file

* update weirdly overspecific test expectations

* oh there's two

* sigh
2024-09-19 08:59:55 -04:00
comfyanonymous
ad66f7c7d8 Add model_options to load_controlnet function. 2024-09-19 08:23:35 -04:00
Simon Lui
de8e8e3b0d Fix xpu Pytorch nightly build from calling optimize which doesn't exist. (#4978) 2024-09-19 05:11:42 -04:00
Alex "mcmonkey" Goodwin
a1e71cfad1 very simple strong-cache on model list (#4969)
* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
2024-09-19 04:40:14 -04:00
comfyanonymous
0bfc7cc998 Create the temp directory on ComfyUI startup instead. 2024-09-18 09:55:57 -04:00
Tom
7183fd1665 Add route to list model types (#4846)
* Add list models route

* Better readable model types list
2024-09-17 04:22:05 -04:00
Alex "mcmonkey" Goodwin
254838f23c add simple error check to model loading (#4950) 2024-09-17 03:57:17 -04:00
pharmapsychotic
0b7dfa986d Improve tiling calculations to reduce number of tiles that need to be processed. (#4944) 2024-09-17 03:51:10 -04:00
comfyanonymous
d514bb38ee Add some option to model_options for the text encoder.
load_device, offload_device and the initial_device can now be set.
2024-09-17 03:49:54 -04:00
comfyanonymous
0849c80e2a get_key_patches now works without unloading the model. 2024-09-17 01:57:59 -04:00
comfyanonymous
56e8f5e4fd VAEDecodeAudio now does some normalization on the audio. 2024-09-16 00:30:36 -04:00
comfyanonymous
e813abbb2c Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
2024-09-15 07:59:38 -04:00
JettHu
5e68a4ce67 Reduce repeated calls of INPUT_TYPES in cache (#4922) 2024-09-15 01:03:09 -04:00
comfyanonymous
ca08597670 Make the inpaint controlnet node work with non inpaint ones. 2024-09-14 09:17:13 -04:00
comfyanonymous
f48e390032 Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
2024-09-14 09:05:16 -04:00
Chenlei Hu
369a6dd2c4 Remove empty spaces in user_manager.py (#4917) 2024-09-13 23:30:44 -04:00
comfyanonymous
b3ce8fb9fd Revert "Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871)"
This reverts commit f6b7194f64.
2024-09-13 23:24:47 -04:00
comfyanonymous
cf80d28689 Support loading controlnets with different input. 2024-09-13 09:54:37 -04:00
Acly
6fb44c4b7c Make adding links/nodes to ExecutionList non-recursive (#4886)
Graphs with 300+ chained nodes run into maximum recursion depth error (limit is 1000 in CPython)
2024-09-13 08:25:11 -04:00
Chenlei Hu
d2247c1e61 Normalize path returned by /userdata to always use / as separator (#4906) 2024-09-13 03:45:31 -04:00
Chenlei Hu
cb12ad7049 Add full_info flag in /userdata endpoint to list out file size and last modified timestamp (#4905)
* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
2024-09-13 02:40:59 -04:00
JettHu
f6b7194f64 Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871) 2024-09-12 23:02:52 -04:00
comfyanonymous
7c6eb4fb29 Set some nodes as DEPRECATED. 2024-09-12 20:27:07 -04:00
Robin Huang
b962db9952 Add cli arg to override user directory (#4856)
* Override user directory.

* Use overridden user directory.

* Remove prints.

* Remove references to global user_files.

* Remove unused replace_folder function.

* Remove newline.

* Remove global during get_user_directory.

* Add validation.
2024-09-12 08:10:27 -04:00
comfyanonymous
d0b7ab88ba Add a simple experimental TorchCompileModel node.
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
2024-09-12 05:24:25 -04:00
Yoland Yan
405b529545 Minor: update tests-unit README.md (#4896) 2024-09-12 04:53:08 -04:00
comfyanonymous
9d720187f1 types -> comfy_types to fix import issue. 2024-09-12 03:57:46 -04:00
Robin Huang
d247bc5a9c Expand variables in base_path for extra_config_paths.yaml. (#4893)
* Expand variables in base_path for extra_config_paths.yaml.

* Fix comments.
2024-09-12 01:52:06 -04:00
comfyanonymous
9f4daca9d9 Doesn't really make sense for cfg_pp sampler to call regular one. 2024-09-11 02:51:36 -04:00
yoinked
b5d0f2a908 Add CFG++ to DPM++ 2S Ancestral (#3871)
* Update sampling.py

* Update samplers.py

* my bad

* "fix" the sampler

* Update samplers.py

* i named it wrong

* minor sampling improvements

mainly using a dynamic rho value (hey this sounds a lot like smea!!!)

* revert rho change

rho? r? its just 1/2
2024-09-11 02:49:44 -04:00
bymyself
e760bf5c40 Add content-type filter method to folder_paths (#4054)
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
2024-09-11 02:00:07 -04:00
comfyanonymous
36c83cdbba Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people
who use reverse proxies.
2024-09-11 01:06:37 -04:00
Yoland Yan
81778a7feb [🗻 Mount Fuji Commit] Add unit tests for folder path utilities (#4869)
All past 30 min of comtts are done on the top of Mt Fuji
By Comfy, Robin, and Yoland
All other comfy org members died on the way

Introduced unit tests to verify the correctness of various folder path
utility functions such as `get_directory_by_type`, `annotated_filepath`,
and `recursive_search` among others. These tests cover scenarios
including directory retrieval, filepath annotation, recursive file
searches, and filtering files by extensions, enhancing the robustness
and reliability of the codebase.
2024-09-10 00:44:49 -04:00
comfyanonymous
bc94662b31 Cleanup. 2024-09-10 00:43:37 -04:00
Robin Huang
9fa8faa44a Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path.

* Add test.

* Add unit test for expanding base path.

* Simplify unit test.

* Remove comment.

* Remove comment.

* Checkpoints.

* Refactor.
2024-09-10 00:33:44 -04:00
comfyanonymous
9a7444e39f Add diffusion_models to the extra_model_paths.yaml.example 2024-09-10 00:21:33 -04:00
comfyanonymous
54fca4a218 If host does not contain a port only compare the hostnames. 2024-09-09 16:28:23 -04:00
Chenlei Hu
cd4955367e Add back CI action for tests-ui (#4859) 2024-09-09 04:32:55 -04:00
david02871
8354203d95 Add .venv to gitignore (#4756) 2024-09-09 04:31:18 -04:00
comfyanonymous
e0b41243b4 Fix issue where sometimes origin doesn't contain the port. 2024-09-09 03:18:17 -04:00
Alex "mcmonkey" Goodwin
619263d4a6 allow current timestamp in save image prefix (#4030) 2024-09-09 02:55:51 -04:00
comfyanonymous
e3b0402bb7 Ignore origin domain when it's empty. 2024-09-09 01:04:56 -04:00
Darion
967867d48c fix: url decode filename from API (#4801) 2024-09-08 21:02:32 -04:00
comfyanonymous
cbaac71bf5 Fix issue with last commit. 2024-09-08 19:35:23 -04:00
comfyanonymous
3ab3516e46 By default only accept requests where origin header matches the host.
Browsers are dumb and let any website do requests to localhost this should
prevent this without breaking things. CORS prevents the javascript from
reading the response but they can still write it.

At the moment this is only enabled when the --enable-cors-header argument
is not used.
2024-09-08 18:17:29 -04:00
comfyanonymous
9c5fca75f4 Fix lora issue. 2024-09-08 10:10:47 -04:00
guill
a5da4d0b3e Fix error with ExecutionBlocker and OUTPUT_IS_LIST (#4836)
This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
2024-09-08 09:48:47 -04:00
comfyanonymous
32a60a7bac Support onetrainer text encoder Flux lora. 2024-09-08 09:31:41 -04:00
Jim Winkens
bb52934ba4 Fix import issue (#4815) 2024-09-07 05:28:32 -04:00
comfyanonymous
8aabd7c8c0 SaveLora node can now save "full diff" lora format.
This isn't actually a lora format and is saving the full diff of the
weights in a format that can be used in the lora loader nodes.
2024-09-07 03:21:02 -04:00
comfyanonymous
a09b29ca11 Add an option to the SaveLora node to store the bias diff. 2024-09-07 03:03:30 -04:00
comfyanonymous
9bfee68773 LoraSave node now supports generating text encoder loras.
text_encoder_diff should be connected to a CLIPMergeSubtract node.

model_diff and text_encoder_diff are optional inputs so you can create
model only loras, text encoder only loras or a lora that contains both.
2024-09-07 02:30:12 -04:00
comfyanonymous
ea77750759 Support a generic Comfy format for text encoder loras.
This is a format with keys like:
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats you can
convert your text encoder loras to this format instead.

If you want to see an example save a text encoder lora with the SaveLora
node with the commit right after this one.
2024-09-07 02:20:39 -04:00
comfyanonymous
c27ebeb1c2 Fix onnx export not working on flux. 2024-09-06 03:21:52 -04:00
guill
0c7c98a965 Nodes using UNIQUE_ID as input are NOT_IDEMPOTENT (#4793)
As suggested by @ltdrdata, we can automatically consider nodes that take
the UNIQUE_ID hidden input to be NOT_IDEMPOTENT.
2024-09-05 19:33:02 -04:00
comfyanonymous
dc2eb75b85 Update stable release workflow to latest pytorch with cuda 12.4. 2024-09-05 19:21:52 -04:00
Chenlei Hu
fa34efe3bd Update frontend to v1.2.47 (#4798)
* Update web content to release v1.2.47

* Update shortcut list
2024-09-05 18:56:01 -04:00
comfyanonymous
5cbaa9e07c Mistoline flux controlnet support. 2024-09-05 00:05:17 -04:00
comfyanonymous
c7427375ee Prioritize freeing partially offloaded models first. 2024-09-04 19:47:32 -04:00
comfyanonymous
22d1241a50 Add an experimental LoraSave node to extract model loras.
The model_diff input should be connected to the output of a
ModelMergeSubtract node.
2024-09-04 16:38:38 -04:00
Jedrzej Kosinski
f04229b84d Add emb_patch support to UNetModel forward (#4779) 2024-09-04 14:35:15 -04:00
Silver
f067ad15d1 Make live preview size a configurable launch argument (#4649)
* Make live preview size a configurable launch argument

* Remove import from testing phase

* Update cli_args.py
2024-09-03 19:16:38 -04:00
comfyanonymous
483004dd1d Support newer glora format. 2024-09-03 17:02:19 -04:00
comfyanonymous
00a5d08103 Lower fp8 lora memory usage. 2024-09-03 01:25:05 -04:00
comfyanonymous
d043997d30 Flux onetrainer lora. 2024-09-02 08:22:15 -04:00
Alex "mcmonkey" Goodwin
f1c2301697 fix typo in stale-issues (#4735) 2024-09-01 17:44:49 -04:00
comfyanonymous
8d31a6632f Speed up inference on nvidia 10 series on Linux. 2024-09-01 17:29:31 -04:00
comfyanonymous
b643eae08b Make minimum_inference_memory() depend on --reserve-vram 2024-09-01 01:18:34 -04:00
comfyanonymous
baa6b4dc36 Update manual install instructions. 2024-08-31 04:37:23 -04:00
Alex "mcmonkey" Goodwin
d4aeefc297 add github action to automatically handle stale user support issues (#4683)
* add github action to automatically handle stale user support issues

* improve stale message

* remove token part
2024-08-31 01:57:18 -04:00
comfyanonymous
587e7ca654 Remove github buttons. 2024-08-31 01:53:10 -04:00
Chenlei Hu
c90459eba0 Update ComfyUI_frontend to 1.2.40 (#4691)
* Update ComfyUI_frontend to 1.2.40

* Add files
2024-08-30 19:32:10 -04:00
Vedat Baday
04278afb10 feat: return import_failed from init_extra_nodes function (#4694) 2024-08-30 19:26:47 -04:00
comfyanonymous
935ae153e1 Cleanup. 2024-08-30 12:53:59 -04:00
Chenlei Hu
e91662e784 Get logs endpoint & system_stats additions (#4690)
* Add route for getting output logs

* Include ComfyUI version

* Move to own function

* Changed to memory logger

* Unify logger setup logic

* Fix get version git fallback

---------

Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
2024-08-30 12:46:37 -04:00
comfyanonymous
63fafaef45 Fix potential issue with hydit controlnets. 2024-08-30 04:58:41 -04:00
Alex "mcmonkey" Goodwin
ec28cd9136 swap legacy sdv15 link (#4682)
* swap legacy sdv15 link

* swap v15 ckpt examples to safetensors

* link the fp16 copy of the model by default
2024-08-29 19:48:48 -04:00
comfyanonymous
6eb5d64522 Fix glora lowvram issue. 2024-08-29 19:07:23 -04:00
comfyanonymous
10a79e9898 Implement model part of flux union controlnet. 2024-08-29 18:41:22 -04:00
comfyanonymous
ea3f39bd69 InstantX depth flux controlnet. 2024-08-29 02:14:19 -04:00
comfyanonymous
b33cd61070 InstantX canny controlnet. 2024-08-28 19:02:50 -04:00
Dr.Lt.Data
34eda0f853 fix: remove redundant useless loop (#4656)
fix: potential error of undefined variable

https://github.com/comfyanonymous/ComfyUI/discussions/4650
2024-08-28 17:46:30 -04:00
comfyanonymous
d31e226650 Unify RMSNorm code. 2024-08-28 16:56:38 -04:00
comfyanonymous
b79fd7d92c ComfyUI supports more than just stable diffusion. 2024-08-28 16:12:24 -04:00
comfyanonymous
38c22e631a Fix case where model was not properly unloaded in merging workflows. 2024-08-27 19:03:51 -04:00
Chenlei Hu
6bbdcd28ae Support weight padding on diff weight patch (#4576) 2024-08-27 13:55:37 -04:00
comfyanonymous
ab130001a8 Do RMSNorm in native type. 2024-08-27 02:41:56 -04:00
Chenlei Hu
ca4b8f30e0 Cleanup empty dir if frontend zip download failed (#4574) 2024-08-27 02:07:25 -04:00
Robin Huang
70b84058c1 Add relative file path to the progress report. (#4621) 2024-08-27 02:06:12 -04:00
comfyanonymous
2ca8f6e23d Make the stochastic fp8 rounding reproducible. 2024-08-26 15:12:06 -04:00
comfyanonymous
7985ff88b9 Use less memory in float8 lora patching by doing calculations in fp16. 2024-08-26 14:45:58 -04:00
comfyanonymous
c6812947e9 Fix potential memory leak. 2024-08-26 02:07:32 -04:00
comfyanonymous
9230f65823 Fix some controlnets OOMing when loading. 2024-08-25 05:54:29 -04:00
guill
6ab1e6fd4a [Bug #4529] Fix graph partial validation failure (#4588)
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
2024-08-24 15:34:58 -04:00
comfyanonymous
07dcbc3a3e Clarify how to use high quality previews. 2024-08-24 02:31:03 -04:00
comfyanonymous
8ae23d8e80 Fix onnx export. 2024-08-23 17:52:47 -04:00
comfyanonymous
7df42b9a23 Fix dora. 2024-08-23 04:58:59 -04:00
comfyanonymous
5d8bbb7281 Cleanup. 2024-08-23 04:06:27 -04:00
comfyanonymous
2c1d2375d6 Fix. 2024-08-23 04:04:55 -04:00
Simon Lui
64ccb3c7e3 Rework IPEX check for future inclusion of XPU into Pytorch upstream and do a bit more optimization of ipex.optimize(). (#4562) 2024-08-23 03:59:57 -04:00
Scorpinaus
9465b23432 Added SD15_Inpaint_Diffusers model support for unet_config_from_diffusers_unet function (#4565) 2024-08-23 03:57:08 -04:00
Chenlei Hu
bb4416dd5b Fix task.status.status_str caused by #2666 (#4551)
* Fix task.status.status_str caused by 2666 regression

* fix

* fix
2024-08-22 17:38:30 -04:00
comfyanonymous
c0b0da264b Missing imports. 2024-08-22 17:20:51 -04:00
comfyanonymous
c26ca27207 Move calculate function to comfy.lora 2024-08-22 17:12:00 -04:00
comfyanonymous
7c6bb84016 Code cleanups. 2024-08-22 17:05:12 -04:00
comfyanonymous
c54d3ed5e6 Fix issue with models staying loaded in memory. 2024-08-22 15:58:20 -04:00
comfyanonymous
c7ee4b37a1 Try to fix some lora issues. 2024-08-22 15:32:18 -04:00
David
7b70b266d8 Generalize MacOS version check for force-upcast-attention (#4548)
This code automatically forces upcasting attention for MacOS versions 14.5 and 14.6. My computer returns the string "14.6.1" for `platform.mac_ver()[0]`, so this generalizes the comparison to catch more versions.

I am running MacOS Sonoma 14.6.1 (latest version) and was seeing black image generation on previously functional workflows after recent software updates. This PR solved the issue for me.

See comfyanonymous/ComfyUI#3521
2024-08-22 13:24:21 -04:00
comfyanonymous
8f60d093ba Fix issue. 2024-08-22 10:38:24 -04:00
guill
dafbe321d2 Fix a bug where cached outputs affected IS_CHANGED (#4535)
This change fixes a bug where non-constant values could be passed to the
IS_CHANGED function. This would result in workflows taking an extra
execution before they acted as if they were cached.

The actual change is like 4 characters -- the rest is adding unit tests.
2024-08-21 23:38:46 -04:00
comfyanonymous
5f84ea63e8 Add a shortcut to the nightly package to run with --fast. 2024-08-21 23:36:58 -04:00
comfyanonymous
843a7ff70c fp16 is actually faster than fp32 on a GTX 1080. 2024-08-21 23:23:50 -04:00
comfyanonymous
a60620dcea Fix slow performance on 10 series Nvidia GPUs. 2024-08-21 16:39:02 -04:00
comfyanonymous
015f73dc49 Try a different type of flux fp16 fix. 2024-08-21 16:17:15 -04:00
comfyanonymous
904bf58e7d Make --fast work on pytorch nightly. 2024-08-21 14:01:41 -04:00
Svein Ove Aas
5f50263088 Replace use of .view with .reshape (#4522)
When generating images with fp8_e4_m3 Flux and batch size >1, using --fast, ComfyUI throws a "view size is not compatible with input tensor's size and stride" error pointing at the first of these two calls to view.

As reshape is semantically equivalent to view except for working on a broader set of inputs, there should be no downside to changing this. The only difference is that it clones the underlying data in cases where .view would error out. I have confirmed that the output still looks as expected, but cannot confirm that no mutable use is made of the tensors anywhere.

Note that --fast is only marginally faster than the default.
2024-08-21 11:21:48 -04:00
Alex "mcmonkey" Goodwin
5e806f555d add a get models list api route (#4519)
* get models list api route

* remove copypasta
2024-08-21 02:04:42 -04:00
Robin Huang
f07e5bb522 Add GET /internal/files. (#4295)
* Create internal route table.

* List files.

* Add GET /internal/files.

Retrieves list of files in models, output, and user directories.

* Refactor file names.

* Use typing_extensions for Python 3.8

* Fix tests.

* Remove print statements.

* Update README.

* Add output and user to valid directory test.

* Add missing type hints.
2024-08-21 01:25:06 -04:00
comfyanonymous
03ec517afb Remove useless line, adjust windows default reserved vram. 2024-08-21 00:47:19 -04:00
Chenlei Hu
f257fc999f Add optional deprecated/experimental flag to node class (#4506)
* Add optional deprecated flag to node class

* nit

* Add experimental flag
2024-08-21 00:01:34 -04:00
Chenlei Hu
bb50e69839 Update frontend to 1.2.30 (#4513) 2024-08-21 00:00:49 -04:00
comfyanonymous
510f3438c1 Speed up fp8 matrix mult by using better code. 2024-08-20 22:53:26 -04:00
comfyanonymous
ea63b1c092 Simpletrainer lycoris format. 2024-08-20 12:05:13 -04:00
comfyanonymous
9953f22fce Add --fast argument to enable experimental optimizations.
Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
2024-08-20 11:55:51 -04:00
comfyanonymous
d1a6bd6845 Support loading long clipl model with the CLIP loader node. 2024-08-20 10:46:36 -04:00
comfyanonymous
83dbac28eb Properly set if clip text pooled projection instead of using hack. 2024-08-20 10:46:36 -04:00
comfyanonymous
538cb068bc Make cast_to a nop if weight is already good. 2024-08-20 10:46:36 -04:00
comfyanonymous
1b3eee672c Fix potential issue with multi devices. 2024-08-20 10:46:36 -04:00
Chenlei Hu
5a69f84c3c Update README.md (Add shield badges) (#4490) 2024-08-19 18:25:20 -04:00
comfyanonymous
9eee470244 New load_text_encoder_state_dicts function.
Now you can load text encoders straight from a list of state dicts.
2024-08-19 17:36:35 -04:00
comfyanonymous
045377ea89 Add a --reserve-vram argument if you don't want comfy to use all of it.
--reserve-vram 1.0 for example will make ComfyUI try to keep 1GB vram free.

This can also be useful if workflows are failing because of OOM errors but
in that case please report it if --reserve-vram improves your situation.
2024-08-19 17:16:18 -04:00
comfyanonymous
4d341b78e8 Bug fixes. 2024-08-19 16:28:55 -04:00
comfyanonymous
6138f92084 Use better dtype for the lowvram lora system. 2024-08-19 15:35:25 -04:00
comfyanonymous
be0726c1ed Remove duplication. 2024-08-19 15:26:50 -04:00
comfyanonymous
766ae119a8 CheckpointSave node name. 2024-08-19 15:06:12 -04:00
Yoland Yan
fc90ceb6ba Update issue template config.yml to direct frontend issues to frontend repos (#4486)
* Update config.yml

* Typos
2024-08-19 13:41:30 -04:00
comfyanonymous
4506ddc86a Better subnormal fp8 stochastic rounding. Thanks Ashen. 2024-08-19 13:38:03 -04:00
comfyanonymous
20ace7c853 Code cleanup. 2024-08-19 12:48:59 -04:00
Chenlei Hu
b29b3b86c5 Update README to include frontend section (#4468)
* Update README to include frontend section

* nit
2024-08-19 07:12:32 -04:00
comfyanonymous
22ec02afc0 Handle subnormal numbers in float8 rounding. 2024-08-19 05:51:08 -04:00
comfyanonymous
39f114c44b Less broken non blocking? 2024-08-18 16:53:17 -04:00
comfyanonymous
6730f3e1a3 Disable non blocking.
It fixed some perf issues but caused other issues that need to be debugged.
2024-08-18 14:38:09 -04:00
comfyanonymous
73332160c8 Enable non blocking transfers in lowvram mode. 2024-08-18 10:29:33 -04:00
comfyanonymous
2622c55aff Automatically use RF variant of dpmpp_2s_ancestral if RF model. 2024-08-18 00:47:25 -04:00
Ashen
1beb348ee2 dpmpp_2s_ancestral_RF for rectified flow (Flux, SD3 and Auraflow). 2024-08-18 00:33:30 -04:00
bymyself
9aa39e743c Add new shortcuts to readme (#4442) 2024-08-17 23:52:56 -04:00
comfyanonymous
d31df04c8a Indentation. 2024-08-17 23:00:44 -04:00
Xrvk
e68763f40c Add Flux model support for InstantX style controlnet residuals (#4444)
* Add Flux model support for InstantX style controlnet residuals

* Refactor Flux controlnet residual step to a separate method

* Rollback minor change

* New format for applying controlnet residuals: input->double_blocks, output->single_blocks

* Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals

* Remove unnecessary import and minor style change
2024-08-17 22:58:23 -04:00
comfyanonymous
310ad09258 Add a ModelSave node. 2024-08-17 21:43:07 -04:00
comfyanonymous
4f7a3cb6fb unet -> diffusion_models. 2024-08-17 21:31:04 -04:00
comfyanonymous
bb222ceddb Fix loras having a weak effect when applied on fp8. 2024-08-17 15:20:17 -04:00
comfyanonymous
14af129c55 Improve execution UX.
Some branches with VAELoader -> VAEDecode -> Preview were being executed
last. With this change they will be executed earlier.
2024-08-17 11:37:21 -04:00
comfyanonymous
fca42836f2 Add model_options for text encoder. 2024-08-17 11:17:20 -04:00
comfyanonymous
858d51f91a Fix VAEDecode -> Preview not being executed first. 2024-08-17 04:08:54 -04:00
comfyanonymous
cd5017c1c9 calculate_weight function to use a different dtype. 2024-08-17 01:06:08 -04:00
comfyanonymous
83f343146a Fix potential lowvram issue. 2024-08-16 17:12:42 -04:00
Chenlei Hu
b021cf67c7 Update frontend to 1.2.26 (#4415) 2024-08-16 15:25:02 -04:00
Matthew Turnshek
1770fc77ed Implement support for taef1 latent previews (#4409)
* add taef1 handling to several places

* remove guess_latent_channels and add latent_channels info directly to flux model

* remove TODO

* fix numbers
2024-08-16 12:53:13 -04:00
comfyanonymous
05a9f3faa1 Log a warning when there's an issue with IS_CHANGED. 2024-08-16 08:50:17 -04:00
comfyanonymous
86c5970ac0 Fix custom nodes hooking the map_node_over_list and breaking things. 2024-08-16 08:40:31 -04:00
Chenlei Hu
bfc214f434 Use new TS frontend uncompressed (#4379)
* Swap frontend uncompressed

* Add uncompressed files
2024-08-15 16:50:25 -04:00
comfyanonymous
3f5939add6 Tell github not to count the web directory in language stats. 2024-08-15 13:48:56 -04:00
comfyanonymous
5960f946a9 Move a few files from comfy -> comfy_execution.
Python code in the comfy folder should not import things from outside it.
2024-08-15 11:21:14 -04:00
guill
5cfe38f41c Execution Model Inversion (#2666)
* Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:

    1. The implementation of lazy evaluation in nodes. For example, if a
    "Mix Images" node has a mix factor of exactly 0.0, the second image
    input doesn't even need to be evaluated (and visa-versa if the mix
    factor is 1.0).

    2. Dynamic expansion of nodes. This allows for the creation of dynamic
    "node groups". Specifically, custom nodes can return subgraphs that
    replace the original node in the graph. This is an incredibly
    powerful concept. Using this functionality, it was easy to
    implement:
        a. Components (a.k.a. node groups)
        b. Flow control (i.e. while loops) via tail recursion
        c. All-in-one nodes that replicate the WebUI functionality
        d. and more
    All of those were able to be implemented entirely via custom nodes,
    so those features are *not* a part of this PR. (There are some
    front-end changes that should occur before that functionality is
    made widely available, particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.

* Allow `input_info` to be of type `None`

* Handle errors (like OOM) more gracefully

* Add a command-line argument to enable variants

This allows the use of nodes that have sockets of type '*' without
applying a patch to the code.

* Fix an overly aggressive assertion.

This could happen when attempting to evaluate `IS_CHANGED` for a node
during the creation of the cache (in order to create the cache key).

* Fix Pyright warnings

* Add execution model unit tests

* Fix issue with unused literals

Behavior should now match the master branch with regard to undeclared
inputs. Undeclared inputs that are socket connections will be used while
undeclared inputs that are literals will be ignored.

* Make custom VALIDATE_INPUTS skip normal validation

Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.

* Fix example in unit test

This wouldn't have caused any issues in the unit test, but it would have
bugged the UI if someone copy+pasted it into their own node pack.

* Use fstrings instead of '%' formatting syntax

* Use custom exception types.

* Display an error for dependency cycles

Previously, dependency cycles that were created during node expansion
would cause the application to quit (due to an uncaught exception). Now,
we'll throw a proper error to the UI. We also make an attempt to 'blame'
the most relevant node in the UI.

* Add docs on when ExecutionBlocker should be used

* Remove unused functionality

* Rename ExecutionResult.SLEEPING to PENDING

* Remove superfluous function parameter

* Pass None for uneval inputs instead of default

This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values
in evaluation functions.

* Add a test for mixed node expansion

This test ensures that a node that returns a combination of expanded
subgraphs and literal values functions correctly.

* Raise exception for bad get_node calls.

* Minor refactor of IsChangedCache.get

* Refactor `map_node_over_list` function

* Fix ui output for duplicated nodes

* Add documentation on `check_lazy_status`

* Add file for execution model unit tests

* Clean up Javascript code as per review

* Improve documentation

Converted some comments to docstrings as per review

* Add a new unit test for mixed lazy results

This test validates that when an output list is fed to a lazy node, the
node will properly evaluate previous nodes that are needed by any inputs
to the lazy node.

No code in the execution model has been changed. The test already
passes.

* Allow kwargs in VALIDATE_INPUTS functions

When kwargs are used, validation is skipped for all inputs as if they
had been mentioned explicitly.

* List cached nodes in `execution_cached` message

This was previously just bugged in this PR.
2024-08-15 11:21:11 -04:00
comfyanonymous
0f9c2a7822 Try to fix SDXL OOM issue on some configurations. 2024-08-14 23:08:54 -04:00
comfyanonymous
153d0a8142 Add a update/update_comfyui_stable.bat to the standalones. 2024-08-14 22:29:23 -04:00
Chenlei Hu
ab4dd19b91 Remove legacy ui test files (#4316) 2024-08-14 21:01:06 -04:00
comfyanonymous
f1d6cef71c Revert "Disable cuda malloc by default."
This reverts commit 50bf66e5c4.
2024-08-14 08:38:07 -04:00
comfyanonymous
33fb282d5c Fix issue. 2024-08-14 02:51:47 -04:00
comfyanonymous
50bf66e5c4 Disable cuda malloc by default. 2024-08-14 02:49:25 -04:00
pythongosssss
e60e19b175 Add support for simple tooltips (#3842)
* Add support for simple tooltips

* Fix overflow

* Add tooltips for nodes in the default workflow

* new line

* Prevent potential crash

* PR feedback

* Hide tooltip when clicking (e.g. combo widget)

* Refactor tooltips, add node level support

* Fix

* move

* Fix test (and undo last change)

* Fixed indent

* Fix dom widgets, dont show tooltip if not over canvas
2024-08-14 01:22:10 -04:00
comfyanonymous
a5af64d3ce Revert "Not sure if this actually changes anything but it can't hurt."
This reverts commit 34608de2e9.
2024-08-14 01:05:17 -04:00
Robin Huang
3e52e0364c Add model downloading endpoint. (#4248)
* Add model downloading endpoint.

* Move client session init to async function.

* Break up large function.

* Send "download_progress" as websocket event.

* Fixed

* Fixed.

* Use async mock.

* Move server set up to right before run call.

* Validate that model subdirectory cannot contain relative paths.

* Add download_model test checking for invalid paths.

* Remove DS_Store.

* Consolidate DownloadStatus and DownloadModelResult

* Add progress_interval as an optional parameter.

* Use tuple type from annotations.

* Use pydantic.

* Update comment.

* Revert "Use pydantic."

This reverts commit 7461e8eb00.

* Add new line.

* Add newline EOF.

* Validate model filename as well.

* Add comment to not reply on internal.

* Restrict downloading to safetensor files only.
2024-08-13 15:48:52 -04:00
comfyanonymous
34608de2e9 Not sure if this actually changes anything but it can't hurt. 2024-08-13 13:29:16 -04:00
comfyanonymous
39fb74c5bd Fix bug when model cannot be partially unloaded. 2024-08-13 03:57:55 -04:00
comfyanonymous
74e124f4d7 Fix some issues with TE being in lowvram mode. 2024-08-12 23:42:21 -04:00
comfyanonymous
a562c17e8a load_unet -> load_diffusion_model with a model_options argument. 2024-08-12 23:20:57 -04:00
comfyanonymous
5942c17d55 Order of operations matters. 2024-08-12 21:56:18 -04:00
comfyanonymous
c032b11e07 xlabs Flux controlnet implementation. (#4260)
* xlabs Flux controlnet.

* Fix not working on old python.

* Remove comment.
2024-08-12 21:22:22 -04:00
262 changed files with 142564 additions and 43887 deletions

View File

@@ -75,6 +75,25 @@ else:
print("pulling latest changes")
pull(repo)
if "--stable" in sys.argv:
def latest_tag(repo):
versions = []
for k in repo.references:
try:
prefix = "refs/tags/v"
if k.startswith(prefix):
version = list(map(int, k[len(prefix):].split(".")))
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
except:
pass
versions.sort()
if len(versions) > 0:
return versions[-1][1]
return None
latest_tag = latest_tag(repo)
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!")
self_update = True
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
shutil.copy(repo_req_path, req_path)
except:
pass
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
try:
if not file_size(stable_update_script_to) > 10:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause

View File

@@ -14,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
RECOMMENDED WAY TO UPDATE:

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

2
.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
/web/assets/** linguist-generated
/web/** linguist-vendored

View File

@@ -1,5 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
- name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@@ -23,7 +23,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

View File

@@ -12,17 +12,17 @@ on:
description: 'CUDA version'
required: true
type: string
default: "121"
default: "124"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "9"
default: "7"
jobs:

21
.github/workflows/stale-issues.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 430 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true

View File

@@ -32,7 +32,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
@@ -55,7 +55,7 @@ jobs:
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

View File

@@ -1,10 +1,4 @@
# This is a temporary action during frontend TS migration.
# This file should be removed after TS migration is completed.
# The browser test is here to ensure TS repo is working the same way as the
# current JS code.
# If you are adding UI feature, please sync your changes to the TS repo:
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
name: Playwright Browser Tests CI
name: Test server launches without errors
on:
push:
@@ -21,15 +15,6 @@ jobs:
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- name: Checkout ComfyUI_frontend
uses: actions/checkout@v4
with:
repository: "huchenlei/ComfyUI_frontend"
path: "ComfyUI_frontend"
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
- uses: actions/setup-node@v3
with:
node-version: lts/*
- uses: actions/setup-python@v4
with:
python-version: '3.8'
@@ -45,16 +30,6 @@ jobs:
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
run: |
npm ci
working-directory: ComfyUI_frontend
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ComfyUI_frontend
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
@@ -62,12 +37,6 @@ jobs:
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4
if: always()
with:

View File

@@ -1,29 +1,29 @@
name: Tests CI
name: Unit Tests
on: [push, pull_request]
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt

View File

@@ -12,7 +12,7 @@ on:
description: 'extra dependencies'
required: false
type: string
default: "\"numpy<2\""
default: ""
cu:
description: 'cuda version'
required: true
@@ -23,13 +23,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

View File

@@ -67,6 +67,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2

View File

@@ -13,13 +13,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

1
.gitignore vendored
View File

@@ -12,6 +12,7 @@ extra_model_paths.yaml
.vscode/
.idea/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/

View File

@@ -1,8 +1,35 @@
ComfyUI
=======
The most powerful and modular stable diffusion GUI and backend.
-----------
<div align="center">
# ComfyUI
**The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
[![][github-release-date-shield]][github-release-link]
[![][github-downloads-shield]][github-downloads-link]
[![][github-downloads-latest-shield]][github-downloads-link]
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](comfyui_screenshot.png)
</div>
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
@@ -66,10 +94,14 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
@@ -95,6 +127,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
## Manual Install (Windows, Linux)
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@@ -105,17 +139,17 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements:
@@ -200,7 +234,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@@ -216,6 +250,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA
### Which GPU should I buy for this?

0
api_server/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,3 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.

View File

View File

@@ -0,0 +1,51 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger
class InternalRoutes:
'''
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
try:
file_list = self.file_service.list_files(directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:
self._app = web.Application()
self.setup_routes()
self._app.add_routes(self.routes)
return self._app

View File

View File

@@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
def list_files(self, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)

View File

@@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard
class FileInfo(TypedDict):
name: str
path: str
type: Literal["file"]
size: int
class DirectoryInfo(TypedDict):
name: str
path: str
type: Literal["directory"]
FileSystemItem = Union[FileInfo, DirectoryInfo]
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
return item["type"] == "file"
class FileSystemOperations:
@staticmethod
def walk_directory(directory: str) -> List[FileSystemItem]:
file_list: List[FileSystemItem] = []
for root, dirs, files in os.walk(directory):
for name in files:
file_path = os.path.join(root, name)
relative_path = os.path.relpath(file_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "file",
"size": os.path.getsize(file_path)
})
for name in dirs:
dir_path = os.path.join(root, name)
relative_path = os.path.relpath(dir_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "directory"
})
return file_list

View File

@@ -8,7 +8,7 @@ import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
from typing import TypedDict, Optional
import requests
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
@@ -150,7 +151,16 @@ class FrontendManager:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
@@ -158,15 +168,21 @@ class FrontendManager:
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
return web_root
@classmethod

31
app/logger.py Normal file
View File

@@ -0,0 +1,31 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
def get_logs():
return "\n".join([formatter.format(x) for x in logs])
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)
# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)

View File

@@ -5,17 +5,17 @@ import uuid
import glob
import shutil
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
from folder_paths import user_directory
import folder_paths
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
@@ -25,14 +25,17 @@ class UserManager():
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +47,7 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
@@ -59,6 +62,10 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +87,7 @@ class UserManager():
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
with open(self.get_users_file(), "w") as f:
json.dump(self.users, f)
return user_id
@@ -112,25 +118,69 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)
return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)
return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path):
return web.Response(status=404)
return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
results = [[x] + x.split(os.sep) for x in results]
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return web.json_response(results)
@@ -138,14 +188,14 @@ class UserManager():
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if check_exists and not os.path.exists(path):
return web.Response(status=404)
return path
@routes.get("/userdata/{file}")
@@ -153,7 +203,7 @@ class UserManager():
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path
return web.FileResponse(path)
@routes.post("/userdata/{file}")
@@ -161,7 +211,7 @@ class UserManager():
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
@@ -170,7 +220,7 @@ class UserManager():
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp)
@@ -181,7 +231,7 @@ class UserManager():
return path
os.remove(path)
return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}")
@@ -189,17 +239,17 @@ class UserManager():
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)

View File

@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
self.in_channels,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -92,6 +92,12 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
@@ -112,10 +118,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -126,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -161,6 +171,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
@@ -171,10 +183,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):

View File

@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
sd.pop(k)
return clip
def load(ckpt_path):

View File

@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -58,7 +60,7 @@ class StrengthType(Enum):
LINEAR_UP = 2
class ControlBase:
def __init__(self, device=None):
def __init__(self):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
@@ -70,20 +72,24 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@@ -98,9 +104,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@@ -121,6 +127,8 @@ class ControlBase:
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -146,7 +154,7 @@ class ControlBase:
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype:
if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
@@ -173,8 +181,8 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
super().__init__(device)
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__()
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
@@ -187,6 +195,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -204,7 +213,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@@ -212,6 +220,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -219,7 +230,15 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -234,7 +253,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -318,8 +337,8 @@ class ControlLoraOps:
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
@@ -375,21 +394,28 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd):
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -401,25 +427,31 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd):
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
@@ -427,13 +459,49 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None
supported_inference_dtypes = None
@@ -488,8 +556,15 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
return load_controlnet_mmdit(controlnet_data)
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@@ -501,26 +576,38 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logging.error("error could not detect control model type.")
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -554,22 +641,32 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
super().__init__()
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
@@ -617,7 +714,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = comfy.sd.load_unet(unet_path)
unet = comfy.sd.load_diffusion_model(unet_path)
clip = None
if output_clip:

67
comfy/float.py Normal file
View File

@@ -0,0 +1,67 @@
import torch
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
return value.to(dtype=torch.float16)
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
import comfy.model_patcher
import comfy.model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
@@ -152,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -169,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
@@ -509,6 +546,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -541,6 +581,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -1016,7 +1105,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@@ -1043,8 +1131,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
@@ -141,30 +145,60 @@ class StableAudio1(LatentFormat):
latent_channels = 64
class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Mochi(LatentFormat):
latent_channels = 12
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors = None #TODO
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean

View File

@@ -1,4 +1,5 @@
import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,21 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@@ -0,0 +1,205 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
else:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
controlnet_double = ()
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
img = torch.cat((txt, img), 1)
controlnet_single = ()
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else:
hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

View File

@@ -6,6 +6,7 @@ from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
@@ -170,15 +168,15 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img += img_mod1.gate * self.img_attn.proj(img_attn)
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504)
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = x.clip(-65504, 65504)
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x

View File

@@ -38,7 +38,7 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -106,24 +108,40 @@ class Flux(nn.Module):
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -133,10 +151,10 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@@ -0,0 +1,541 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
AttentionPool,
modulate,
)
import comfy.ldm.common_dit
import comfy.ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
# Apply tanh to gate
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.dim_x = dim_x
self.dim_y = dim_y
self.num_heads = num_heads
self.head_dim = dim_x // num_heads
self.attn_drop = attn_drop
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
)
# Input layers.
self.qkv_bias = qkv_bias
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Project text features to match visual features (dim_y -> dim_x)
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Query and key normalization for stability.
assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
self.proj_y = (
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
if update_y
else nn.Identity()
)
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
# assert qkv_x.dtype == torch.bfloat16
# qkv_x = all_to_all_collect_tokens(
# qkv_x, self.num_heads
# ) # (3, B, N, local_h, head_dim)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
# print("ox", x)
# print("oy", y)
return x, y
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
if self.update_y:
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
else:
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
hidden_size_y,
num_heads=num_heads,
update_y=update_y,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
# MLP for text not needed in last block.
if self.update_y:
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
self.mlp_y = FeedForward(
in_features=hidden_size_y,
hidden_size=mlp_hidden_dim_y,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
Args:
x: (B, N, dim) tensor of visual tokens
c: (B, dim) tensor of conditioned features
y: (B, L, dim) tensor of text tokens
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim) tensor of visual tokens after block
y: (B, L, dim) tensor of text tokens after block
"""
N = x.size(1)
c = F.silu(c)
mod_x = self.mod_x(c)
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
mod_y = self.mod_y(c)
if self.update_y:
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
else:
scale_msa_y = mod_y
# Self-attention block.
x_attn, y_attn = self.attn(
x,
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
**attn_kwargs,
)
assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
)
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
self.linear = operations.Linear(
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
)
def forward(self, x, c):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
Ingests text embeddings instead of a label.
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
use_t5: bool = False,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
learn_sigma=True,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
attend_to_padding: bool = False,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
posenc_preserve_area: bool = False,
rope_theta: float = 10000.0,
image_model=None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.attend_to_padding = attend_to_padding
self.use_extended_posenc = use_extended_posenc
self.posenc_preserve_area = posenc_preserve_area
self.use_t5 = use_t5
self.t5_token_length = t5_token_length
self.t5_feat_dim = t5_feat_dim
self.rope_theta = (
rope_theta # Scaling factor for frequency computation for temporal RoPE.
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size_x,
bias=patch_embed_bias,
dtype=dtype,
device=device,
operations=operations
)
# Conditionings
# Timestep
self.t_embedder = TimestepEmbedder(
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
if self.use_t5:
# Caption Pooling (T5)
self.t5_y_embedder = AttentionPool(
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
)
# Dense Embedding Projection (T5)
self.t5_yproj = operations.Linear(
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
)
# Initialize pos_frequencies as an empty parameter.
self.pos_frequencies = nn.Parameter(
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
)
assert not self.attend_to_padding
# for depth 48:
# b = 0: AsymmetricJointBlock, update_y=True
# b = 1: AsymmetricJointBlock, update_y=True
# ...
# b = 46: AsymmetricJointBlock, update_y=True
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
blocks = []
for b in range(depth):
# Joint multi-modal block
update_y = b < depth - 1
block = AsymmetricJointBlock(
hidden_size_x,
hidden_size_y,
num_heads,
mlp_ratio_x=mlp_ratio_x,
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
attend_to_padding=attend_to_padding,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.final_layer = FinalLayer(
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C=12, T, H, W) tensor of visual tokens
Returns:
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
"""
return self.x_embedder(x) # Convert BcTHW to BCN
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
# Visual patch embeddings with positional encoding.
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
pos = create_position_matrix(
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
c = c_t + t5_y_pool
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, **kwargs
):
y_feat = context
y_mask = attention_mask
sigma = timestep
"""Forward pass of DiT.
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat, y_mask
)
del y_mask
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
return -x

View File

@@ -0,0 +1,164 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.timestep_scale = timestep_scale
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
freqs.mul_(-math.log(max_period) / half).exp_()
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t, out_dtype):
if self.timestep_scale is not None:
t = t * self.timestep_scale
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
hidden_size = int(2 * hidden_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_size = int(ffn_dim_multiplier * hidden_size)
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
self.hidden_dim = hidden_size
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
def forward(self, x):
x, gate = self.w1(x).chunk(2, dim=-1)
x = self.w2(F.silu(x) * gate)
return x
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.flatten = flatten
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
device=device,
dtype=dtype,
)
assert norm_layer is None
self.norm = (
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
)
def forward(self, x):
B, _C, T, H, W = x.shape
if not self.dynamic_img_pad:
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
else:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
if not self.flatten:
raise NotImplementedError("Must flatten output.")
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
x = self.norm(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

View File

@@ -0,0 +1,88 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# import functools
import math
import torch
def centers(start: float, stop, num, dtype=None, device=None):
"""linspace through bin centers.
Args:
start (float): Start of the range.
stop (float): End of the range.
num (int): Number of points.
dtype (torch.dtype): Data type of the points.
device (torch.device): Device of the points.
Returns:
centers (Tensor): Centers of the bins. Shape: (num,).
"""
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
return (edges[:-1] + edges[1:]) / 2
# @functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,
pW: int,
device: torch.device,
dtype: torch.dtype,
*,
target_area: float = 36864,
):
"""
Args:
T: int - Temporal dimension
pH: int - Height dimension after patchify
pW: int - Width dimension after patchify
Returns:
pos: [T * pH * pW, 3] - position matrix
"""
# Create 1D tensors for each dimension
t = torch.arange(T, dtype=dtype)
# Positionally interpolate to area 36864.
# (3072x3072 frame with 16x16 patches = 192x192 latents).
# This automatically scales rope positions when the resolution changes.
# We use a large target area so the model is more sensitive
# to changes in the learned pos_frequencies matrix.
scale = math.sqrt(target_area / (pW * pH))
w = centers(-pW * scale / 2, pW * scale / 2, pW)
h = centers(-pH * scale / 2, pH * scale / 2, pH)
# Use meshgrid to create 3D grids
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
return pos
def compute_mixed_rotation(
freqs: torch.Tensor,
pos: torch.Tensor,
):
"""
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
Args:
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
pos: [N, 3] - position of each token
num_heads: int
Returns:
freqs_cos: [N, num_heads, num_freqs] - cosine components
freqs_sin: [N, num_heads, num_freqs] - sine components
"""
assert freqs.ndim == 3
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
freqs_cos = torch.cos(freqs_sum)
freqs_sin = torch.sin(freqs_sum)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,34 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# Based on Llama3 Implementation.
import torch
def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
Args:
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
Can be either just query or just key, or both stacked along some batch or * dim.
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
# Split the last dimension into even and odd parts
xqk_even = xqk[..., 0::2]
xqk_odd = xqk[..., 1::2]
# Apply rotation
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
return out

View File

@@ -0,0 +1,102 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x

View File

@@ -0,0 +1,480 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
class GroupNormSpatial(ops.GroupNorm):
"""
GroupNorm applied per-frame.
"""
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
B, C, T, H, W = x.shape
x = rearrange(x, "B C T H W -> (B T) C H W")
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
# Compute padding amounts.
context_size = self.kernel_size[0] - 1
if self.causal:
pad_front = context_size
pad_back = 0
else:
pad_front = context_size // 2
pad_back = context_size - pad_front
# Apply padding.
assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
return super().forward(x)
class Conv1x1(ops.Linear):
"""*1x1 Conv implemented with a linear layer."""
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
super().__init__(in_features, out_features, *args, **kwargs)
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, *] or [B, *, C].
Returns:
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# When printed, this module should show the temporal and spatial expansion factors.
def extra_repr(self):
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
Returns:
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
"""
x = rearrange(
x,
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
st=self.temporal_expansion,
sh=self.spatial_expansion,
sw=self.spatial_expansion,
)
# cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions."""
def __init__(
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
):
super().__init__()
self.channels = channels
assert causal
self.stack = nn.Sequential(
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
),
)
self.attn_block = attn_block if attn_block else nn.Identity()
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
"""
residual = x
x = self.stack(x)
x = x + residual
del residual
return self.attn_block(x)
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
blocks = []
for _ in range(num_res_blocks):
blocks.append(block_fn(in_channels, **block_kwargs))
self.blocks = nn.Sequential(*blocks)
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.d2st = DepthToSpaceTime(
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
)
def forward(self, x):
x = self.blocks(x)
x = self.proj(x)
x = self.d2st(x)
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code.
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
# Change the channel count in the strided convolution.
# This lets the ResBlock have uniform channel count,
# as in ConvNeXt.
assert in_channels != out_channels
layers.append(
PConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
padding_mode="replicate",
bias=True,
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
num_freqs = (stop - start) // step
assert inputs.ndim == 5
C = inputs.size(1)
# Create Base 2 Fourier features.
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
assert num_freqs == len(freqs)
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
C = inputs.shape[1]
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w.
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
return torch.cat(
[
inputs,
torch.sin(h),
torch.cos(h),
],
dim=1,
)
class FourierFeatures(nn.Module):
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
super().__init__()
self.start = start
self.stop = stop
self.step = step
def forward(self, inputs):
"""Add Fourier features to inputs.
Args:
inputs: Input tensor. Shape: [B, C, T, H, W]
Returns:
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
"""
return add_fourier_features(inputs, self.start, self.stop, self.step)
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.output_nonlinearity = output_nonlinearity
assert nonlinearity == "silu"
assert causal
ch = [mult * base_channels for mult in channel_multipliers]
self.num_up_blocks = len(ch) - 1
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
**block_kwargs,
)
)
blocks.append(nn.Sequential(*first_block))
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
upsample_block_fn = CausalUpsampleBlock
for i in range(self.num_up_blocks):
block = upsample_block_fn(
ch[-i - 1],
ch[-i - 2],
num_res_blocks=num_res_blocks[-i - 2],
has_attention=has_attention[-i - 2],
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
**block_kwargs,
)
blocks.append(block)
assert not output_norm
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
last_block.append(
block_fn(
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
)
)
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
self.output_proj = Conv1x1(ch[0], out_channels)
def forward(self, x):
"""Forward pass.
Args:
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
Returns:
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
T + 1 = (t - 1) * 4.
H = h * 16, W = w * 16.
"""
for block in self.blocks:
x = block(x)
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
assert (
not self.output_nonlinearity
) # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
self.encoder = None #TODO once the model releases
self.decoder = Decoder(
out_channels=3,
base_channels=128,
channel_multipliers=[1, 2, 4, 6],
temporal_expansions=[1, 2, 3],
spatial_expansions=[2, 2, 2],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
causal=True,
)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)

View File

@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

View File

@@ -358,7 +358,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(

View File

@@ -1,11 +1,11 @@
import logging
import math
from typing import Dict, Optional
from typing import Dict, Optional, List
import numpy as np
import torch
import torch.nn as nn
from .. import attention
from ..attention import optimized_attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = self.pre_attention(x)
q, k, v = self.pre_attention(x)
x = optimized_attention(
qkv, num_heads=self.num_heads
q, k, v, heads=self.num_heads
)
x = self.post_attention(x)
return x
@@ -355,29 +353,9 @@ class RMSNorm(torch.nn.Module):
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
class SwiGLUFeedForward(nn.Module):
@@ -437,6 +415,7 @@ class DismantledBlock(nn.Module):
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
dtype=None,
device=None,
operations=None,
@@ -460,6 +439,24 @@ class DismantledBlock(nn.Module):
device=device,
operations=operations
)
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
self.x_block_self_attn = True
self.attn2 = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=False,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
operations=operations
)
else:
self.x_block_self_attn = False
if not pre_only:
if not rmsnorm:
self.norm2 = operations.LayerNorm(
@@ -486,7 +483,11 @@ class DismantledBlock(nn.Module):
multiple_of=256,
)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
n_mods = 9
elif not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
@@ -547,14 +548,64 @@ class DismantledBlock(nn.Module):
)
return x
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert self.x_block_self_attn
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = self.adaLN_modulation(c).chunk(9, dim=1)
x_norm = self.norm1(x)
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
return qkv, qkv2, (
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
)
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv,
num_heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
if self.x_block_self_attn:
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
attn, _ = optimized_attention(
qkv[0], qkv[1], qkv[2],
num_heads=self.attn.num_heads,
)
attn2, _ = optimized_attention(
qkv2[0], qkv2[1], qkv2[2],
num_heads=self.attn2.num_heads,
)
return self.post_attention_x(attn, attn2, *intermediates)
else:
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
def block_mixing(*args, use_checkpoint=True, **kwargs):
@@ -569,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
def _block_mixing(context, x, context_block, x_block, c):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
x_qkv, x_intermediates = x_block.pre_attention(x, c)
if x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o = []
for t in range(3):
@@ -577,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
qkv = tuple(o)
attn = optimized_attention(
qkv,
num_heads=x_block.attn.num_heads,
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -590,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
else:
context = None
x = x_block.post_attention(x_attn, *x_intermediates)
if x_block.x_block_self_attn:
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
@@ -605,8 +666,13 @@ class JointBlock(nn.Module):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args,
pre_only=False,
qk_norm=qk_norm,
x_block_self_attn=x_block_self_attn,
**kwargs)
def forward(self, *args, **kwargs):
return block_mixing(
@@ -662,7 +728,7 @@ class SelfAttentionContext(nn.Module):
def forward(self, x):
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.dim_head)
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
return self.proj(x)
class ContextProcessorBlock(nn.Module):
@@ -721,9 +787,12 @@ class MMDiT(nn.Module):
qk_norm: Optional[str] = None,
qkv_bias: bool = True,
context_processor_layers = None,
x_block_self_attn: bool = False,
x_block_self_attn_layers: Optional[List[int]] = [],
context_size = 4096,
num_blocks = None,
final_layer = True,
skip_blocks = False,
dtype = None, #TODO
device = None,
operations = None,
@@ -738,6 +807,7 @@ class MMDiT(nn.Module):
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
self.x_block_self_attn_layers = x_block_self_attn_layers
# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
@@ -795,26 +865,28 @@ class MMDiT(nn.Module):
self.pos_embed = None
self.use_checkpoint = use_checkpoint
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations
)
for i in range(num_blocks)
]
)
if not skip_blocks:
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(num_blocks)
]
)
if final_layer:
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -877,7 +949,9 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if self.register_length > 0:
context = torch.cat(
(
@@ -889,14 +963,25 @@ class MMDiT(nn.Module):
# context is B, L', D
# x is B, L, D
blocks_replace = patches_replace.get("dit", {})
blocks = len(self.joint_blocks)
for i in range(blocks):
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if control is not None:
control_o = control.get("output")
if i < len(control_o):
@@ -914,6 +999,7 @@ class MMDiT(nn.Module):
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
"""
Forward pass of DiT.
@@ -935,7 +1021,7 @@ class MMDiT(nn.Module):
if context is not None:
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context, control)
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]]
@@ -949,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

View File

@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

View File

@@ -16,8 +16,12 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import logging
import torch
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@@ -197,9 +201,13 @@ def load_lora(lora, to_load):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -223,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
clip_g_present = True
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
@@ -238,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
t5_index = 1
if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
@@ -277,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
@@ -299,6 +317,10 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
key_map[key_lora] = to
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -318,7 +340,256 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight

View File

@@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
@@ -96,10 +97,8 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -247,6 +246,10 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -716,3 +719,18 @@ class Flux(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
if context_processor in state_dict_keys:
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
unet_config["x_block_self_attn_layers"] = []
for key in state_dict_keys:
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
unet_config["x_block_self_attn_layers"].append(int(layer))
return unet_config
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
@@ -145,6 +150,34 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
dit_config = {}
dit_config["image_model"] = "mochi_preview"
dit_config["depth"] = 48
dit_config["patch_size"] = 2
dit_config["num_heads"] = 24
dit_config["hidden_size_x"] = 3072
dit_config["hidden_size_y"] = 1536
dit_config["mlp_ratio_x"] = 4.0
dit_config["mlp_ratio_y"] = 4.0
dit_config["learn_sigma"] = False
dit_config["in_channels"] = 12
dit_config["qk_norm"] = True
dit_config["qkv_bias"] = False
dit_config["out_bias"] = True
dit_config["attn_drop"] = 0.0
dit_config["patch_embed_bias"] = True
dit_config["posenc_preserve_area"] = True
dit_config["timestep_mlp_bias"] = True
dit_config["attend_to_padding"] = False
dit_config["timestep_scale"] = 1000.0
dit_config["use_t5"] = True
dit_config["t5_feat_dim"] = 4096
dit_config["t5_token_length"] = 256
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -286,9 +319,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
else:
return model_config
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
if scaled_fp8_weight is not None:
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
return model_config
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models
@@ -472,9 +511,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
total_vram = 0
lowvram_available = True
xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
except:
pass
lowvram_available = True
if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
@@ -66,10 +72,10 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
except:
pass
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
try:
if torch.backends.mps.is_available():
@@ -139,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch.version.__version__))
logging.info("pytorch version: {}".format(torch_version))
except:
pass
@@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@@ -315,17 +320,15 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram)
else:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
@@ -338,8 +341,9 @@ class LoadedModel:
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
self.model.partially_unload(self.model.offload_device, memory_to_free)
return False
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
@@ -366,8 +370,21 @@ def offloaded_memory(loaded_models, device):
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
@@ -391,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
@@ -407,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
@@ -434,15 +453,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)
@@ -513,7 +532,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
@@ -552,7 +571,9 @@ def loaded_models(only_currently_used=False):
def cleanup_models(keep_clone_weights_loaded=False):
to_delete = []
for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
#TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
@@ -605,6 +626,8 @@ def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -624,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
pass
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
return fp8_dtype
free_model_memory = maximum_vram_for_weights(device)
if model_params * 2 > free_model_memory:
return fp8_dtype
@@ -659,6 +685,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@@ -816,27 +843,21 @@ def force_channels_last():
#TODO
return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
device_supports_cast = True
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
elif is_intel_xpu():
device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
non_blocking = device_should_use_non_blocking(device)
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
@@ -874,7 +895,8 @@ def pytorch_attention_flash_attention():
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
upcast = True
except:
pass
@@ -970,23 +992,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
if props.major < 6:
return False
fp16_works = False
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
fp16_works = True
if WINDOWS or manual_cast:
return True
else:
return False #weird linux behavior where fp32 is faster
if fp16_works or manual_cast:
if manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -1026,7 +1048,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@@ -1039,6 +1061,27 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -22,32 +22,26 @@ import inspect
import logging
import uuid
import collections
import math
import comfy.utils
import comfy.float
import comfy.model_management
from comfy.types import UnetWrapperFunction
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
import comfy.lora
from comfy.comfy_types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -90,12 +84,40 @@ def wipe_lowvram_weight(m):
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
def __init__(self, key, patches):
self.key = key
self.model_patcher = model_patcher
self.patches = patches
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
intermediate_dtype = weight.dtype
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
def get_key_weight(model, key):
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(model, key)
else:
op = comfy.utils.get_attr(model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass
try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass
weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(model, key)
return weight, set_func, convert_func
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
@@ -290,17 +312,23 @@ class ModelPatcher:
return list(p)
def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
weight, set_func, convert_func = get_key_weight(self.model, k)
if bk is not None:
weight = bk.weight
if convert_func is None:
convert_func = lambda a, **kwargs: a
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
p[k] = [(weight, convert_func)] + self.patches[k]
else:
p[k] = (model_sd[k],)
p[k] = [(weight, convert_func)]
return p
def model_state_dict(self, filter_prefix=None):
@@ -316,8 +344,7 @@ class ModelPatcher:
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
@@ -327,47 +354,42 @@ class ModelPatcher:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
self.patch_weight_to_device(key, device_to)
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
module_mem = x[0]
lowvram_weight = False
if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
weight_key = "{}.weight".format(n)
@@ -378,13 +400,13 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -395,205 +417,56 @@ class ModelPatcher:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
mem_counter += comfy.model_management.module_size(m)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
mem_counter += module_mem
load_completely.append((module_mem, n, m))
weight_to = None
if full_load:#TODO
weight_to = device_to
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key, device_to=weight_to)
m.to(device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter))
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
if lowvram_model_memory == 0:
full_load = True
else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model.model_lowvram:
@@ -619,6 +492,10 @@ class ModelPatcher:
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@@ -628,40 +505,47 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
unload_list.append((module_mem, n, m))
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
@@ -670,15 +554,19 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False)
self.patch_model(load_weights=False)
full_load = False
if self.model.model_lowvram == False:
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self):
return self.model.device
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

View File

@@ -18,29 +18,34 @@
import torch
import comfy.model_management
from comfy.cli_args import args
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to(weight, dtype=None, device=None, non_blocking=False):
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_bias_weight(s, input=None, dtype=None, device=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
if s.bias_function is not None:
has_function = s.bias_function is not None
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
if s.weight_function is not None:
has_function = s.weight_function is not None
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
@@ -238,3 +243,124 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
self.scale_weight = None
self.scale_input = None
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
return manual_cast

View File

@@ -6,7 +6,7 @@ from comfy import model_management
import math
import logging
import comfy.sampler_helpers
import scipy
import scipy.stats
import numpy
def get_area_and_mult(conds, x_in, timestep_in):
@@ -358,11 +358,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
sigs = []
last_t = -1
for t in ts:
sigs += [float(model_sampling.sigmas[int(t)])]
if t != last_t:
sigs += [float(model_sampling.sigmas[int(t)])]
last_t = t
sigs += [0.0]
return torch.FloatTensor(sigs)
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
if steps == 1:
sigma_schedule = [1.0, 0.0]
else:
if linear_steps is None:
linear_steps = steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * steps
quadratic_steps = steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
const = quadratic_coef * (linear_steps ** 2)
quadratic_sigma_schedule = [
quadratic_coef * (i ** 2) + linear_coef * i + const
for i in range(linear_steps, steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -570,8 +594,8 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler):
@@ -729,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
@@ -747,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas

View File

@@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import yaml
import comfy.utils
@@ -24,11 +25,12 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -62,18 +64,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
if no_init:
return
params = target.params.copy()
clip = target.clip
tokenizer = target.tokenizer
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
dtype = model_management.text_encoder_dtype(load_device)
load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
@@ -86,7 +93,7 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
if params['device'] == load_device:
model_management.load_model_gpu(self.patcher)
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
@@ -236,6 +243,13 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
self.latent_channels = 12
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -291,6 +305,10 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -309,6 +327,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def decode(self, samples_in):
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -316,16 +335,21 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
if len(samples_in.shape) == 3:
dims = samples_in.ndim - 2
if dims == 1:
pixel_samples = self.decode_tiled_1d(samples_in)
else:
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@@ -342,7 +366,7 @@ class VAE:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
@@ -393,11 +417,53 @@ class CLIPType(Enum):
STABLE_AUDIO = 4
HUNYUAN_DIT = 5
FLUX = 6
MOCHI = 7
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
class TEModel(Enum):
CLIP_L = 1
CLIP_H = 2
CLIP_G = 3
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return TEModel.CLIP_G
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
return None
def t5xxl_detect(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
class EmptyClass:
pass
@@ -412,59 +478,65 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
te_model = detect_te_model(clip_data[0])
if te_model == TEModel.CLIP_G:
if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
elif te_model == TEModel.CLIP_H:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
if weight.shape[-1] == 4096:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
elif te_model == TEModel.T5_XXL:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@@ -506,14 +578,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None
clipvision = None
vae = None
@@ -530,11 +602,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("weight_dtype", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
@@ -548,7 +620,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
@@ -563,7 +634,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@@ -585,12 +656,13 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -599,6 +671,8 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
sd = temp_sd
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
@@ -625,13 +699,21 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
@@ -640,24 +722,36 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]

View File

@@ -75,16 +75,15 @@ class ClipTokenWeightEncoder:
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@@ -94,8 +93,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
self.operations = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
if scaled_fp8 is not None:
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
self.max_length = max_length
@@ -539,6 +551,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -552,8 +565,12 @@ class SD1Tokenizer:
def state_dict(self):
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
super().__init__()
if name is not None:
@@ -563,7 +580,8 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()
if dtype is not None:

View File

@@ -3,14 +3,14 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,10 +39,11 @@ class SDXLTokenizer:
return {}
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
def set_clip_options(self, options):
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -66,8 +69,8 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@@ -79,14 +82,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)

View File

@@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
from . import supported_models_base
from . import latent_formats
@@ -181,7 +182,7 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
memory_usage_factor = 0.7
memory_usage_factor = 0.8
def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
@@ -529,12 +530,11 @@ class SD3(supported_models_base.BASE):
clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
if "dtype_t5" in t5_detect:
t5 = True
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
class StableAudio(supported_models_base.BASE):
unet_config = {
@@ -653,10 +653,8 @@ class Flux(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxSchnell(Flux):
unet_config = {
@@ -673,7 +671,36 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
sampling_settings = {
"multiplier": 1.0,
"shift": 6.0,
}
unet_extra_config = {}
latent_format = latent_formats.Mochi
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
models += [SVD_img2vid]

View File

@@ -49,6 +49,8 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
optimizations = {"fp8": False}
@classmethod
def matches(s, unet_config, state_dict=None):
@@ -71,6 +73,7 @@ class BASE:
self.unet_config = unet_config.copy()
self.sampling_settings = self.sampling_settings.copy()
self.latent_format = self.latent_format()
self.optimizations = self.optimizations.copy()
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]

View File

@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@@ -1,24 +1,21 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
import comfy.model_management
from transformers import T5TokenizerFast
import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -35,11 +32,12 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_

View File

@@ -0,0 +1,38 @@
from comfy import sd1_clip
import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
kwargs["attention_mask"] = True
super().__init__(**kwargs)
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_

View File

@@ -7,9 +7,9 @@ import os
import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -50,10 +50,10 @@ class HyditTokenizer:
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:

View File

@@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@@ -0,0 +1,30 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)

View File

@@ -2,13 +2,13 @@ from comfy import sd1_clip
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

View File

@@ -8,19 +8,38 @@ import comfy.model_management
import logging
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def t5_xxl_detect(state_dict, prefix=""):
out = {}
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -38,24 +57,26 @@ class SD3Tokenizer:
return {}
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.t5_attention_mask = t5_attention_mask
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
lg_out = None
pooled = None
out = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.t5_attention_mask:
extra["attention_mask"] = t5_output[2]["attention_mask"]
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
return out, pooled
return out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
if len(dtypes) == 0:
return None
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
}
for k in MAP_BASIC:
@@ -688,9 +690,14 @@ def lanczos(samples, width, height):
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
orig_shape = tuple(samples.shape)
if len(orig_shape) > 4:
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
samples = samples.movedim(2, 1)
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_width = samples.shape[-1]
old_height = samples.shape[-2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
@@ -699,48 +706,87 @@ def common_upscale(samples, width, height, upscale_method, crop):
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
else:
s = samples
if upscale_method == "bislerp":
return bislerp(s, width, height)
out = bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
out = lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if len(orig_shape) == 4:
return out
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
dims = len(tile)
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return up * val
def mult_list_upscale(a):
out = []
for i in range(len(a)):
out.append(round(get_upscale(i, a[i])))
return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount))
upscaled.append(round(get_upscale(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
for d in range(2, dims + 2):
m = mask.narrow(d, t, 1)
m *= ((1.0/feather) * (t + 1))
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
m *= ((1.0/feather) * (t + 1))
for d in range(2, dims + 2):
feather = round(get_upscale(d - 2, overlap[d - 2]))
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = out
o_d = out_div
@@ -748,8 +794,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask
o_d += mask
o.add_(ps * mask)
o_d.add_(mask)
if pbar is not None:
pbar.update(1)

318
comfy_execution/caching.py Normal file
View File

@@ -0,0 +1,318 @@
import itertools
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
import nodes
from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(self, node_ids):
raise NotImplementedError()
def all_node_ids(self):
return set(self.keys.keys())
def get_used_keys(self):
return self.keys.values()
def get_used_subcache_keys(self):
return self.subcache_keys.values()
def get_data_key(self, node_id):
return self.keys.get(node_id, None)
def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
else:
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
return False
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:
if is_link(inputs[key]):
ancestor_id = inputs[key][0]
if ancestor_id not in order_mapping:
ancestors.append(ancestor_id)
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
self.initialized = False
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
self.is_changed_cache = is_changed_cache
self.initialized = True
def all_node_ids(self):
assert self.initialized
node_ids = self.cache_key_set.all_node_ids()
for subcache in self.subcaches.values():
node_ids = node_ids.union(subcache.all_node_ids())
return node_ids
def _clean_cache(self):
preserve_keys = set(self.cache_key_set.get_used_keys())
to_remove = []
for key in self.cache:
if key not in preserve_keys:
to_remove.append(key)
for key in to_remove:
del self.cache[key]
def _clean_subcaches(self):
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = []
for key in self.subcaches:
if key not in preserve_subcaches:
to_remove.append(key)
for key in to_remove:
del self.subcaches[key]
def clean_unused(self):
assert self.initialized
self._clean_cache()
self._clean_subcaches()
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
assert self.initialized
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
return self.subcaches[subcache_key]
else:
return None
def recursive_debug_dump(self):
result = []
for key in self.cache:
result.append({"key": key, "value": self.cache[key]})
for key in self.subcaches:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
parent_id = self.dynprompt.get_parent_node_id(node_id)
if parent_id is None:
return self
hierarchy = []
while parent_id is not None:
hierarchy.append(parent_id)
parent_id = self.dynprompt.get_parent_node_id(parent_id)
cache = self
for parent_id in reversed(hierarchy):
cache = cache._get_subcache(parent_id)
if cache is None:
return None
return cache
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
for key in to_remove:
del self.cache[key]
del self.used_generation[key]
if key in self.children:
del self.children[key]
self._clean_subcaches()
def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
for child_id in children_ids:
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self

270
comfy_execution/graph.py Normal file
View File

@@ -0,0 +1,270 @@
import nodes
from comfy_execution.graph_utils import is_link
class DependencyCycleError(Exception):
pass
class NodeInputError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
value = inputs[to_input]
if not is_link(value):
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
node_ids = [node_unique_id]
links = []
while len(node_ids) > 0:
unique_id = node_ids.pop()
if unique_id in self.pendingNodes:
continue
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)
def is_cached(self, node_id):
return False
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
"""
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
it can still be returned to the graph after having further dependencies added.
"""
def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None, None, None
available = self.get_ready_nodes()
if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,
# we will 'blame' the first node in the cycle that is not a static node.
blamed_node = cycled_nodes[0]
for node_id in cycled_nodes:
display_node_id = self.dynprompt.get_display_node_id(node_id)
if display_node_id != node_id:
blamed_node = display_node_id
break
ex = DependencyCycleError("Dependency cycle detected")
error_details = {
"node_id": blamed_node,
"exception_message": str(ex),
"exception_type": "graph.DependencyCycleError",
"traceback": [],
"current_inputs": []
}
return None, error_details, ex
self.staged_node_id = self.ux_friendly_pick_node(available)
return self.staged_node_id, None, None
def ux_friendly_pick_node(self, node_list):
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
for node_id in node_list:
if is_output(node_id):
return node_id
#This should handle the VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
if is_output(blocked_node_id):
return node_id
#This should handle the VAELoader -> VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
for blocked_node_id1 in self.blocking[blocked_node_id]:
if is_output(blocked_node_id1):
return node_id
#TODO: this function should be improved
return node_list[0]
def unstage_node_execution(self):
assert self.staged_node_id is not None
self.staged_node_id = None
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.staged_node_id = None
def get_nodes_in_cycle(self):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes }
for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values():
blocked_by[to_node_id][from_node_id] = True
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
while len(to_remove) > 0:
for node_id in to_remove:
for to_node_id in blocked_by:
if node_id in blocked_by[to_node_id]:
del blocked_by[to_node_id][node_id]
del blocked_by[node_id]
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message

View File

@@ -0,0 +1,139 @@
def is_link(obj):
if not isinstance(obj, list):
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
self.prefix = prefix
self.nodes = {}
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
cls._default_prefix_graph_index = graph_index
@classmethod
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
if root is None:
root = GraphBuilder._default_prefix_root
if call_index is None:
call_index = GraphBuilder._default_prefix_call_index
if graph_index is None:
graph_index = GraphBuilder._default_prefix_graph_index
result = f"{root}.{call_index}.{graph_index}."
GraphBuilder._default_prefix_graph_index += 1
return result
def node(self, class_type, id=None, **kwargs):
if id is None:
id = str(self.id_gen)
self.id_gen += 1
id = self.prefix + id
if id in self.nodes:
return self.nodes[id]
node = Node(id, class_type, kwargs)
self.nodes[id] = node
return node
def lookup_node(self, id):
id = self.prefix + id
return self.nodes.get(id)
def finalize(self):
output = {}
for node_id, node in self.nodes.items():
output[node_id] = node.serialize()
return output
def replace_node_output(self, node_id, index, new_value):
node_id = self.prefix + node_id
to_remove = []
for node in self.nodes.values():
for key, value in node.inputs.items():
if is_link(value) and value[0] == node_id and value[1] == index:
if new_value is None:
to_remove.append((node, key))
else:
node.inputs[key] = new_value
for node, key in to_remove:
del node.inputs[key]
def remove_node(self, id):
id = self.prefix + id
del self.nodes[id]
class Node:
def __init__(self, id, class_type, inputs):
self.id = id
self.class_type = class_type
self.inputs = inputs
self.override_display_id = None
def out(self, index):
return [self.id, index]
def set_input(self, key, value):
if value is None:
if key in self.inputs:
del self.inputs[key]
else:
self.inputs[key] = value
def get_input(self, key):
return self.inputs.get(key)
def set_override_display_id(self, override_display_id):
self.override_display_id = override_display_id
def serialize(self):
serialized = {
"class_type": self.class_type,
"inputs": self.inputs
}
if self.override_display_id is not None:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links
new_graph = {}
for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
else:
new_node["inputs"][input_name] = input_value
new_graph[new_node_id] = new_node
# Change the node IDs in the outputs
new_outputs = []
for n in range(len(outputs)):
output = outputs[n]
if is_link(output):
new_outputs.append([prefix + output[0], output[1]])
else:
new_outputs.append(output)
return new_graph, tuple(new_outputs)

View File

@@ -16,14 +16,15 @@ class EmptyLatentAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds):
batch_size = 1
def generate(self, seconds, batch_size):
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, )
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
}
class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio"

View File

@@ -1,4 +1,6 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
class SetUnionControlNetType:
@classmethod
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
return (control_net,)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}

View File

@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,

View File

@@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:

View File

@@ -1,4 +1,5 @@
import comfy.utils
import comfy_extras.nodes_post_processing
import torch
def reshape_latent_to(target_shape, latent):
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
return (samples_out,)
class LatentApplyOperation:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)
class LatentApplyOperationCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def patch(self, model, operation):
m = model.clone()
def pre_cfg_function(args):
conds_out = args["conds_out"]
if len(conds_out) == 2:
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
else:
conds_out[0] = operation(latent=conds_out[0])
return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )
class LatentOperationTonemapReinhard:
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
top = (std * 5 + mean) * multiplier
#reinhard
latent_vector_magnitude *= (1.0 / top)
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
new_magnitude *= top
return normalized_latent * new_magnitude
return (tonemap_reinhard,)
class LatentOperationSharpen:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"sharpen_radius": ("INT", {
"default": 9,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance
channels = latent.shape[1]
kernel_size = sharpen_radius * 2 + 1
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
center = kernel_size // 2
kernel *= alpha * -10
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened
return (sharpen,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
"LatentOperationSharpen": LatentOperationSharpen,
}

View File

@@ -0,0 +1,119 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraSave": "Extract and Save Lora"
}

View File

@@ -0,0 +1,26 @@
import nodes
import torch
import comfy.model_management
class EmptyMochiLatentVideo:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/mochi"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}

View File

@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling")

View File

@@ -333,6 +333,25 @@ class VAESave:
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
class ModelSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
@@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}

View File

@@ -101,10 +101,34 @@ class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
}

View File

@@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch"
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()

View File

@@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:

View File

@@ -3,7 +3,7 @@ import comfy.sd
import comfy.model_management
import nodes
import torch
import re
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
@@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3:
@@ -93,15 +93,75 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
class SkipLayerGuidanceSD3:
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
if layers == "" or layers == None:
return (model, )
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result
layers = re.findall(r'\d+', layers)
layers = [int(i) for i in layers]
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
"EmptySD3LatentImage": EmptySD3LatentImage,
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
"ControlNetApplySD3": "Apply Controlnet with VAE",
}

View File

@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):

View File

@@ -154,7 +154,7 @@ class TomePatchModel:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, ratio):
self.u = None

View File

@@ -0,0 +1,22 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders"
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
CATEGORY = "advanced/model_merging"
@classmethod
def INPUT_TYPES(s):

View File

@@ -4,14 +4,14 @@ class Example:
Class methods
-------------
INPUT_TYPES (dict):
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re executed.
Attributes
----------
RETURN_TYPES (`tuple`):
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
@@ -23,13 +23,19 @@ class Example:
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
@@ -54,7 +60,8 @@ class Example:
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number" # Cosmetic only: display as "number" or "slider"
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
@@ -62,11 +69,14 @@ class Example:
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number"}),
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!"
"default": "Hello World!",
"lazy": True
}),
},
}
@@ -80,6 +90,23 @@ class Example:
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:

View File

@@ -37,6 +37,7 @@ class SaveImageWebsocket:
return {}
@classmethod
def IS_CHANGED(s, images):
return time.time()

View File

@@ -5,6 +5,7 @@ import threading
import heapq
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
@@ -12,102 +13,225 @@ import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
PENDING = 2
class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
if "is_changed" in node:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN")
finally:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_type, input_category, input_info = get_input_info(class_def, x)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
input_data_all[x] = (None,)
if outputs is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
mark_missing()
continue
obj = outputs[input_unique_id][output_index]
if output_index >= len(cached_output):
mark_missing()
continue
obj = cached_output[output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
elif input_category is not None:
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [prompt]
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
return input_data_all
return input_data_all, missing_keys
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
input_is_list = obj.INPUT_IS_LIST
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
if len(input_data_all) == 0:
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()])
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
for k, v in inputs.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb else v
break
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
else:
results.append(execution_block)
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
process_inputs(input_data_all, 0)
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
else:
process_inputs({})
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
return results
def get_output_data(obj, input_data_all):
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
value = []
for o in results:
if isinstance(o[i], ExecutionBlocker):
value.append(o[i])
else:
value.extend(o[i])
output.append(value)
else:
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
if 'expand' in r:
# Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif 'result' in r:
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
return output, ui, has_subgraph
def format_value(x):
if x is None:
@@ -117,53 +241,145 @@ def format_value(x):
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
if unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id]
resolved_outputs = []
for is_subgraph, result in cached_results:
if not is_subgraph:
resolved_outputs.append(result)
else:
resolved_output = []
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
resolved_output.append(o)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
else:
resolved_output.append(r)
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
has_subgraph = False
else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
)]
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
def execution_block_cb(block):
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": f"Execution Blocked: {block.message}",
"exception_type": "ExecutionBlocked",
"traceback": [],
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
return ExecutionBlocker(None)
else:
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
new_output_ids = []
new_output_links = []
for i in range(len(output_data)):
new_graph, node_outputs = output_data[i]
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
new_output_links.append((from_node_id, from_socket))
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
}
return (False, error_details, iex)
return (ExecutionResult.FAILURE, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
@@ -173,121 +389,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
logging.error(f"!!! Exception during processing!!! {ex}")
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error(traceback.format_exc())
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (False, error_details, ex)
return (ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item
if unique_id in memo:
return memo[unique_id]
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
return []
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif class_type != old_prompt[unique_id]['class_type']:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
del d
return to_delete
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server):
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server
self.reset()
def reset(self):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True
self.old_prompt = {}
def add_message(self, event, data: dict, broadcast: bool):
data = {
@@ -318,26 +449,13 @@ class PromptExecutor:
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
"current_outputs": list(current_outputs),
}
self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
@@ -351,65 +469,59 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
executed = set()
output_node_id = None
to_execute = []
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
execution_list.add_node(node_id)
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
memo = {}
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@@ -426,31 +538,37 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = []
valid = True
validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}
for x in required_inputs:
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
assert extra_info is not None
if x not in inputs:
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
if input_category == "required":
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
}
errors.append(error)
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
info = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@@ -469,8 +587,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
received_type = r[val[1]]
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
@@ -521,6 +640,9 @@ def validate_inputs(prompt, item, validated):
if type_input == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
@@ -536,11 +658,11 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
if x not in validate_function_inputs and not validate_has_kwargs:
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -550,10 +672,10 @@ def validate_inputs(prompt, item, validated):
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -564,7 +686,6 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@@ -591,18 +712,20 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
if x in validate_function_inputs or validate_has_kwargs:
input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
if r is not True and not isinstance(r, ExecutionBlocker):
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
@@ -613,8 +736,6 @@ def validate_inputs(prompt, item, validated):
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
@@ -780,7 +901,7 @@ class PromptQueue:
completed: bool
messages: List[str]
def task_done(self, item_id, outputs,
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
prompt = self.currently_running.pop(item_id)
@@ -793,9 +914,10 @@ class PromptQueue:
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
"outputs": {},
'status': status_dict,
}
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
def get_current_queue(self):

View File

@@ -25,11 +25,16 @@ a111:
#comfyui:
# base_path: path/to/comfyui/
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# diffusion_models: |
# models/diffusion_models
# models/unet
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/

View File

@@ -2,7 +2,9 @@ from __future__ import annotations
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -17,7 +19,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
@@ -44,6 +46,44 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
@@ -74,6 +114,13 @@ def get_input_directory() -> str:
global input_directory
return input_directory
def get_user_directory() -> str:
return user_directory
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
@@ -85,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
@@ -126,14 +195,19 @@ def exists_annotated_filepath(name) -> bool:
return os.path.exists(filepath)
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
@@ -160,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
@@ -180,6 +258,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name]
@@ -193,7 +272,16 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
@@ -206,8 +294,13 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
@@ -227,11 +320,13 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
return out
def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
@@ -247,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))

Some files were not shown because too many files have changed in this diff Show More