Compare commits

..

148 Commits

Author SHA1 Message Date
Jedrzej Kosinski
e5396e98d8 Add VAELoaderDevice node to device what device to load VAE on 2025-03-21 14:57:05 -05:00
Jedrzej Kosinski
4879b47648 Merge branch 'master' into worksplit-multigpu 2025-03-18 22:19:32 -05:00
comfyanonymous
3b19fc76e3 Allow disabling pe in flux code for some other models. 2025-03-18 05:09:25 -04:00
Jedrzej Kosinski
5ccec33c22 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-03-17 14:27:39 -05:00
Jedrzej Kosinski
219d3cd0d0 Merge branch 'master' into worksplit-multigpu 2025-03-17 14:26:35 -05:00
comfyanonymous
50614f1b79 Fix regression with clip vision. 2025-03-17 13:56:11 -04:00
comfyanonymous
6dc7b0bfe3 Add support for giant dinov2 image encoder. 2025-03-17 05:53:54 -04:00
comfyanonymous
e8e990d6b8 Cleanup code. 2025-03-16 06:29:12 -04:00
Jedrzej Kosinski
2e24a15905 Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253)
* Call unpatch_hooks at the start of ModelPatcher.partially_unload

* Only call unpatch_hooks in partially_unload if lowvram is possible
2025-03-16 06:02:45 -04:00
chaObserv
fd5297131f Guard the edge cases of noise term in er_sde (#7265) 2025-03-16 06:02:25 -04:00
Jedrzej Kosinski
c4ba399475 Merge branch 'master' into worksplit-multigpu 2025-03-15 09:12:09 -05:00
comfyanonymous
55a1b09ddc Allow loading diffusion model files with the "Load Checkpoint" node. 2025-03-15 08:27:49 -04:00
comfyanonymous
3c3988df45 Show a better error message if the VAE is invalid. 2025-03-15 08:26:36 -04:00
Christian Byrne
7ebd8087ff hotfix fe (#7244) 2025-03-15 01:38:10 -04:00
Chenlei Hu
c624c29d66 Update frontend to 1.12.9 (#7236)
* Update frontend to 1.12.9

* Update requirements.txt
2025-03-14 18:17:26 -04:00
comfyanonymous
a2448fc527 Remove useless code. 2025-03-14 18:10:37 -04:00
comfyanonymous
6a0daa79b6 Make the SkipLayerGuidanceDIT node work on WAN. 2025-03-14 10:55:19 -04:00
FeepingCreature
9c98c6358b Tolerate missing @torch.library.custom_op (#7234)
This can happen on Pytorch versions older than 2.4.
2025-03-14 09:51:26 -04:00
FeepingCreature
7aceb9f91c Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag.
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
2025-03-14 03:22:41 -04:00
Jedrzej Kosinski
cc928a786d Merge branch 'master' into worksplit-multigpu 2025-03-13 20:59:11 -05:00
comfyanonymous
35504e2f93 Fix. 2025-03-13 15:03:18 -04:00
comfyanonymous
299436cfed Print mac version. 2025-03-13 10:05:40 -04:00
Chenlei Hu
52e566d2bc Add codeowner for comfy/comfy_types (#7213) 2025-03-12 17:30:00 -04:00
Chenlei Hu
9b6cd9b874 [NodeDef] Add documentation on multi_select input option (#7212) 2025-03-12 17:29:39 -04:00
chaObserv
3fc688aebd Ensure the extra_args in dpmpp sde series (#7204) 2025-03-12 17:28:59 -04:00
comfyanonymous
f4411250f3 Repeat frontend version warning at the end.
This way someone running ComfyUI with the command line is more likely to
actually see it.
2025-03-12 07:13:40 -04:00
Chenlei Hu
d2a0fb6bb0 Add unwrap widget value support (#7197)
* Add unwrap widget value support

* nit
2025-03-12 06:39:14 -04:00
chaObserv
01015bff16 Add er_sde sampler (#7187) 2025-03-12 02:42:37 -04:00
comfyanonymous
2330754b0e Fix error saving some latents. 2025-03-11 15:07:16 -04:00
comfyanonymous
bc219a6487 Merge pull request #7143 from christian-byrne/fix-remote-widget-node
Fix LoadImageOutput node
2025-03-11 04:30:25 -04:00
comfyanonymous
94689766ad Merge pull request #7179 from comfyanonymous/ignore_fe_package
Only check frontend package if using default frontend
2025-03-11 03:45:02 -04:00
huchenlei
cfbe4b49ca Access package version 2025-03-10 20:43:59 -04:00
comfyanonymous
ca8efab79f Support control loras on Wan. 2025-03-10 17:23:13 -04:00
Chenlei Hu
65ea778a5e nit 2025-03-10 15:19:59 -04:00
Chenlei Hu
db9f2a34fc Fix unit test 2025-03-10 15:19:52 -04:00
Chenlei Hu
7946049794 nit 2025-03-10 15:14:40 -04:00
Chenlei Hu
6f6349b6a7 nit 2025-03-10 15:10:40 -04:00
Chenlei Hu
1f138dd382 Only check frontend package if using default frontend 2025-03-10 15:07:44 -04:00
comfyanonymous
b779349b55 Temporarily revert fix to give time for people to update their nodes. 2025-03-10 06:30:17 -04:00
comfyanonymous
35e2dcf5d7 Hack to fix broken manager. 2025-03-10 06:15:17 -04:00
Andrew Kvochko
67c7184b74 ltxv: relax frame_idx divisibility for single frames. (#7146)
This commit relaxes divisibility constraint for single-frame
conditionings. For single frames, the index can be arbitrary, while
multi-frame conditionings (>= 9 frames) must still be aligned to 8
frames.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-03-10 04:11:48 -04:00
comfyanonymous
6f8e766509 Prevent custom nodes from accidentally overwriting global modules. 2025-03-10 03:33:41 -04:00
Terry Jia
e1da98a14a remove unused params (#6931) 2025-03-09 14:07:09 -04:00
bymyself
a73410aafa remove overrides 2025-03-09 03:46:08 -07:00
comfyanonymous
9aac21f894 Fix issues with new hunyuan img2vid model and bumb version to v0.3.26 2025-03-09 05:07:22 -04:00
Jedrzej Kosinski
528d1b3563 When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys (#7067) 2025-03-09 04:26:31 -04:00
comfyanonymous
2bc4b5968f ComfyUI version v0.3.25 2025-03-09 03:30:20 -04:00
Jedrzej Kosinski
6e144b98c4 Merge branch 'master' into worksplit-multigpu 2025-03-09 00:00:38 -06:00
comfyanonymous
7395b0c0d1 Support new hunyuan video i2v model.
Use the new "v2 (replace)" guidance type in HunyuanImageToVideo and set
image_interleave to 4 on the "Text Encode Hunyuan Video" node.
2025-03-08 20:34:47 -05:00
comfyanonymous
0952569493 Fix stable cascade VAE on some lowvram machines. 2025-03-08 20:24:04 -05:00
comfyanonymous
29832b3b61 Warn if frontend package is older than the one in requirements.txt 2025-03-08 03:51:36 -05:00
comfyanonymous
be4e760648 Add an image_interleave option to the Hunyuan image to video encode node.
See the tooltip for what it does.
2025-03-07 19:56:26 -05:00
comfyanonymous
c3d9cc4592 Print the frontend version in the log. 2025-03-07 19:56:26 -05:00
Chenlei Hu
84cc9cb528 Update frontend to 1.11.8 (#7119)
* Update frontend to 1.11.7

* Update requirements.txt
2025-03-07 19:02:13 -05:00
comfyanonymous
ebbb920163 Add back taesd to nightly package. 2025-03-07 14:56:09 -05:00
comfyanonymous
d60fe0af4a Reduce size of nightly package. 2025-03-07 08:30:01 -05:00
comfyanonymous
5dbd250965 Update nightly instructions in readme. 2025-03-07 07:57:59 -05:00
comfyanonymous
4ab1875283 Add .bat file to nightly package to run with fp16 accumulation. 2025-03-07 07:45:40 -05:00
comfyanonymous
11b1f27cb1 Set WAN default compute dtype to fp16. 2025-03-07 04:52:36 -05:00
comfyanonymous
70e15fd743 No need for scale_input when fp8 matrix mult is disabled. 2025-03-07 04:49:20 -05:00
comfyanonymous
e1474150de Support fp8_scaled diffusion models that don't use fp8 matrix mult. 2025-03-07 04:39:21 -05:00
JettHu
e62d72e8ca Typo in node_typing.py (#7092) 2025-03-06 15:24:04 -05:00
Dr.Lt.Data
1650cda030 Fixed: Incorrect guide message for missing frontend. (#7105)
`{sys.executable} -m pip` -> `{sys.executable} -s -m pip`

https://github.com/comfyanonymous/ComfyUI/pull/7047#issuecomment-2697876793
2025-03-06 15:23:23 -05:00
comfyanonymous
a13125840c ComfyUI version v0.3.24 2025-03-06 13:53:48 -05:00
comfyanonymous
dfa36e6855 Fix some things breaking when embeddings fail to apply. 2025-03-06 13:31:55 -05:00
comfyanonymous
0124be4d93 ComfyUI version v0.3.23 2025-03-06 04:10:12 -05:00
comfyanonymous
29a70ca101 Support HunyuanVideo image to video model. 2025-03-06 03:07:15 -05:00
comfyanonymous
0bef826a98 Support llava clip vision model. 2025-03-06 00:24:43 -05:00
comfyanonymous
85ef295069 Make applying embeddings more efficient.
Adding new tokens no longer makes a whole copy of the embeddings weight
which can be massive on certain models.
2025-03-05 17:34:38 -05:00
Chenlei Hu
5d84607bf3 Add type hint for FileLocator (#6968)
* Add type hint for FileLocator

* nit
2025-03-05 15:35:26 -05:00
Silver
c1909f350f Better argument handling of front-end-root (#7043)
* Better argument handling of front-end-root

Improves handling of front-end-root launch argument. Several instances where users have set it and ComfyUI launches as normal and completely disregards the launch arg which doesn't make sense. Better to indicate to user that something is incorrect.

* Removed unused import

There was no real reason to use "Optional" typing in ther front-end-root argument.
2025-03-05 15:34:22 -05:00
Chenlei Hu
52b3469606 [NodeDef] Explicitly add control_after_generate to seed/noise_seed (#7059)
* [NodeDef] Explicitly add control_after_generate to seed/noise_seed

* Update comfy/comfy_types/node_typing.py

Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com>

---------

Co-authored-by: filtered <176114999+webfiltered@users.noreply.github.com>
2025-03-05 15:33:23 -05:00
comfyanonymous
889519971f Bump ComfyUI version to v0.3.22 2025-03-05 10:06:37 -05:00
comfyanonymous
76739c23c3 Revert "Partially revert last commit."
This reverts commit a80bc822a2.
2025-03-05 09:57:40 -05:00
comfyanonymous
a80bc822a2 Partially revert last commit. 2025-03-05 08:58:44 -05:00
Andrew Kvochko
872780d236 fix: ltxv crop guides works with 0 keyframes (#7085)
This patch fixes a bug in LTXVCropGuides when the latent has no
keyframes. Additionally, the first frame is always added as a keyframe.

Co-authored-by: Andrew Kvochko <a.kvochko@lightricks.com>
2025-03-05 08:47:32 -05:00
comfyanonymous
6d45ffbe23 Bump ComfyUI version to v0.3.21 2025-03-05 08:05:22 -05:00
comfyanonymous
77633ba77d Remove unused variable. 2025-03-05 07:31:47 -05:00
comfyanonymous
30e6cfb1a0 Fix LTXVPreprocess on resolutions that are not multiples of 2. 2025-03-05 07:18:13 -05:00
comfyanonymous
dc134b2fdb Bump ComfyUI version to v0.3.20 2025-03-05 06:28:14 -05:00
comfyanonymous
369b079ff6 Fix lowvram issue with ltxv vae. 2025-03-05 05:26:08 -05:00
comfyanonymous
9c9a7f012a Adjust ltxv memory factor. 2025-03-05 05:16:05 -05:00
comfyanonymous
93fedd92fe Support LTXV 0.9.5.
Credits: Lightricks team.
2025-03-05 00:13:49 -05:00
comfyanonymous
745b13649b Add update instructions for the portable. 2025-03-04 23:34:36 -05:00
Dr.Lt.Data
2b140654c7 suggest absolute full path to the requirements.txt instead of just requirements.txt (#7079)
For users of the portable version, there are occasional instances where commands are misinterpreted.
2025-03-04 23:29:34 -05:00
comfyanonymous
65042f7d39 Make it easier to set a custom template for hunyuan video. 2025-03-04 09:26:05 -05:00
comfyanonymous
7c7c70c400 Refactor skyreels i2v code. 2025-03-04 00:15:45 -05:00
Jedrzej Kosinski
6dca17bd2d Satisfy ruff linting 2025-03-03 23:08:29 -06:00
Jedrzej Kosinski
5080105c23 Merge branch 'master' into worksplit-multigpu 2025-03-03 22:56:53 -06:00
Jedrzej Kosinski
093914a247 Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code 2025-03-03 22:56:13 -06:00
comfyanonymous
8362199ee7 Bump ComfyUI version to v0.3.19 2025-03-03 19:18:37 -05:00
comfyanonymous
f86c724ef2 Temporal area composition.
New ConditioningSetAreaPercentageVideo node.
2025-03-03 06:50:31 -05:00
Dr.Lt.Data
d6e5d487ad improved: better frontend package installation guide (#7047)
* improved: better installation guide
- change `pip` to `{sys.executable} -m pip`
modified: To prevent the guide message from being obscured by a complex error message, apply `exit` instead of `raise`.

* ruff fix
2025-03-03 04:40:23 -05:00
comfyanonymous
6752a826f6 Make the missing frontend package error more obvious. 2025-03-02 15:43:56 -05:00
Jedrzej Kosinski
605893d3cf Merge branch 'master' into worksplit-multigpu 2025-02-24 19:23:16 -06:00
Jedrzej Kosinski
048f4f0b3a Merge branch 'master' into worksplit-multigpu 2025-02-17 19:35:58 -06:00
Jedrzej Kosinski
d2504fb701 Merge branch 'master' into worksplit-multigpu 2025-02-11 22:34:51 -06:00
Jedrzej Kosinski
b03763bca6 Merge branch 'multigpu_support' into worksplit-multigpu 2025-02-07 13:27:49 -06:00
Jedrzej Kosinski
476aa79b64 Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups 2025-02-06 08:44:07 -06:00
Jedrzej Kosinski
441cfd1a7a Merge branch 'master' into multigpu_support 2025-02-06 08:10:48 -06:00
Jedrzej Kosinski
99a5c1068a Merge branch 'master' into multigpu_support 2025-02-02 03:19:18 -06:00
Jedrzej Kosinski
02747cde7d Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu 2025-01-29 11:10:23 -06:00
Jedrzej Kosinski
0b3233b4e2 Merge remote-tracking branch 'origin/master' into multigpu_support 2025-01-28 06:11:07 -06:00
Jedrzej Kosinski
eda866bf51 Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet 2025-01-27 06:25:48 -06:00
Jedrzej Kosinski
e3298b84de Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support 2025-01-26 09:34:20 -06:00
Jedrzej Kosinski
c7feef9060 Cast transformer_options for multigpu 2025-01-26 05:29:27 -06:00
Jedrzej Kosinski
51af7fa1b4 Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets 2025-01-25 06:05:01 -06:00
Jedrzej Kosinski
46969c380a Initial MultiGPU support for controlnets 2025-01-24 05:39:38 -06:00
Jedrzej Kosinski
5db4277449 Make sure additional_models are unloaded as well when perform 2025-01-23 19:06:05 -06:00
Jedrzej Kosinski
02a4d0ad7d Added unload_model_and_clones to model_management.py to allow unloading only relevant models 2025-01-23 01:20:00 -06:00
Jedrzej Kosinski
ef137ac0b6 Merge branch 'multigpu_support' of https://github.com/kosinkadink/ComfyUI into multigpu_support 2025-01-20 04:34:39 -06:00
Jedrzej Kosinski
328d4f16a9 Make WeightHooks compatible with MultiGPU, clean up some code 2025-01-20 04:34:26 -06:00
Jedrzej Kosinski
bdbcb85b8d Merge branch 'multigpu_support' of https://github.com/Kosinkadink/ComfyUI into multigpu_support 2025-01-20 00:51:42 -06:00
Jedrzej Kosinski
6c9e94bae7 Merge branch 'master' into multigpu_support 2025-01-20 00:51:37 -06:00
Jedrzej Kosinski
bfce723311 Initial work on multigpu_clone function, which will account for additional_models getting cloned 2025-01-17 03:31:28 -06:00
Jedrzej Kosinski
31f5458938 Merge branch 'master' into multigpu_support 2025-01-16 18:25:05 -06:00
Jedrzej Kosinski
2145a202eb Merge branch 'master' into multigpu_support 2025-01-15 19:58:28 -06:00
Jedrzej Kosinski
25818dc848 Added a 'max_gpus' input 2025-01-14 13:45:14 -06:00
Jedrzej Kosinski
198953cd08 Add nodes_multigpu.py to loaded nodes 2025-01-14 12:24:55 -06:00
Jedrzej Kosinski
ec16ee2f39 Merge branch 'master' into multigpu_support 2025-01-13 20:21:06 -06:00
Jedrzej Kosinski
d5088072fb Make test node for multigpu instead of storing it in just a local __init__.py 2025-01-13 20:20:25 -06:00
Jedrzej Kosinski
8d4b50158e Merge branch 'master' into multigpu_support 2025-01-11 20:16:42 -06:00
Jedrzej Kosinski
e88c6c03ff Fix cond_cat to not try to cast anything that doesn't have a 'to' function 2025-01-10 23:05:24 -06:00
Jedrzej Kosinski
d3cf2b7b24 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-10 20:24:37 -06:00
Jedrzej Kosinski
7448f02b7c Initial proof of concept of giving splitting cond sampling between multiple GPUs 2025-01-08 03:33:05 -06:00
Jedrzej Kosinski
871258aa72 Add get_all_torch_devices to get detected devices intended for current torch hardware device 2025-01-07 21:06:03 -06:00
Jedrzej Kosinski
66838ebd39 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-07 20:11:27 -06:00
Jedrzej Kosinski
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
Jedrzej Kosinski
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
Jedrzej Kosinski
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
Jedrzej Kosinski
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
Jedrzej Kosinski
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
Jedrzej Kosinski
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
Jedrzej Kosinski
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
Jedrzej Kosinski
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
Jedrzej Kosinski
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
Jedrzej Kosinski
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
Jedrzej Kosinski
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
Jedrzej Kosinski
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
Jedrzej Kosinski
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
Jedrzej Kosinski
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
Jedrzej Kosinski
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
Jedrzej Kosinski
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
Jedrzej Kosinski
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
Jedrzej Kosinski
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
Jedrzej Kosinski
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
Jedrzej Kosinski
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
Jedrzej Kosinski
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
71 changed files with 2549 additions and 479 deletions

View File

@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
pause

View File

@@ -22,7 +22,7 @@ on:
description: 'Python patch version' description: 'Python patch version'
required: true required: true
type: string type: string
default: "8" default: "9"
jobs: jobs:

View File

@@ -29,7 +29,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "8" default: "9"
# push: # push:
# branches: # branches:
# - master # - master

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version' description: 'cuda version'
required: true required: true
type: string type: string
default: "126" default: "128"
python_minor: python_minor:
description: 'python minor version' description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "1" default: "2"
# push: # push:
# branches: # branches:
# - master # - master
@@ -34,7 +34,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 30
persist-credentials: false persist-credentials: false
- uses: actions/setup-python@v5 - uses: actions/setup-python@v5
with: with:
@@ -74,7 +74,7 @@ jobs:
pause" > ./update/update_comfyui_and_python_dependencies.bat pause" > ./update/update_comfyui_and_python_dependencies.bat
cd .. cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
cd ComfyUI_windows_portable_nightly_pytorch cd ComfyUI_windows_portable_nightly_pytorch

View File

@@ -19,7 +19,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "8" default: "9"
# push: # push:
# branches: # branches:
# - master # - master

View File

@@ -1,13 +0,0 @@
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.241 # Use the desired version of Ruff
hooks:
- id: ruff
- repo: local
hooks:
- id: pytest
name: Run Pytest
entry: pytest
language: system
types: [python]

View File

@@ -19,5 +19,6 @@
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Extra nodes # Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink /comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered

View File

@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126``` ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
This is the command to install pytorch nightly instead which might have performance improvements: This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126``` ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
#### Troubleshooting #### Troubleshooting
@@ -330,25 +330,6 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/) See also: [https://www.comfy.org/](https://www.comfy.org/)
## ComfyUI Backend Development
### Setup Environment
Install pre-commit to run tests and linters
```
pip install pre-commit
```
```
pre-commit install
```
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the backend, please use the [ComfyUI repository](https://github.com/comfyanonymous/ComfyUI). This will help us manage and address backend-specific concerns more efficiently.
## Frontend Development ## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory. As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.

View File

@@ -3,6 +3,7 @@ import argparse
import logging import logging
import os import os
import re import re
import sys
import tempfile import tempfile
import zipfile import zipfile
import importlib import importlib
@@ -10,19 +11,43 @@ from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import TypedDict, Optional from typing import TypedDict, Optional
from importlib.metadata import version
import requests import requests
from typing_extensions import NotRequired from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date."""
extra = ""
if sys.flags.no_user_site:
extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
try: def check_frontend_version():
import comfyui_frontend_package """Check if the frontend version is up to date."""
except ImportError as e:
# TODO: Remove the check after roll out of 0.3.16 def parse_version(version: str) -> tuple[int, int, int]:
logging.error("comfyui-frontend-package is not installed. Please install the updated requirements.txt file by running: pip install -r requirements.txt") return tuple(map(int, version.split(".")))
raise e
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
REQUEST_TIMEOUT = 10 # seconds REQUEST_TIMEOUT = 10 # seconds
@@ -119,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager: class FrontendManager:
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
sys.exit(-1)
@classmethod @classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]: def parse_version_string(cls, value: str) -> tuple[str, str, str]:
""" """
@@ -158,7 +191,8 @@ class FrontendManager:
main error source might be request timeout or invalid URL. main error source might be request timeout or invalid URL.
""" """
if version_string == DEFAULT_VERSION_STRING: if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH check_frontend_version()
return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
@@ -211,4 +245,5 @@ class FrontendManager:
except Exception as e: except Exception as e:
logging.error("Failed to initialize frontend: %s", e) logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.") logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH check_frontend_version()
return cls.default_frontend_path()

View File

@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger.addHandler(stdout_handler) logger.addHandler(stdout_handler)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
STARTUP_WARNINGS.clear()

View File

@@ -1,7 +1,6 @@
import argparse import argparse
import enum import enum
import os import os
from typing import Optional
import comfy.options import comfy.options
@@ -50,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.")
cm_group = parser.add_mutually_exclusive_group() cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@@ -107,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@@ -166,13 +166,14 @@ parser.add_argument(
""", """,
) )
def is_valid_directory(path: Optional[str]) -> Optional[str]: def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory.""" """Validate if the given path is a directory, and check permissions."""
if path is None: if not os.path.exists(path):
return None raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path): if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.") raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path return path
parser.add_argument( parser.add_argument(

View File

@@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype) x = self.embeddings(input_tokens, dtype=dtype)
mask = None mask = None
if attention_mask is not None: if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -116,6 +120,9 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate: if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i) i = self.final_layer_norm(i)
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output return x, i, pooled_output
@@ -204,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :]) pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module): class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
@@ -213,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else: else:
self.visual_projection = lambda a: a self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs) x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2]) out = self.visual_projection(x[2])
return (x[0], x[1], out) projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)

View File

@@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management import comfy.model_management
import comfy.utils import comfy.utils
import comfy.clip_model import comfy.clip_model
import comfy.image_encoders.dino2
class Output: class Output:
def __getitem__(self, key): def __getitem__(self, key):
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0 image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1]) return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel(): class ClipVisionModel():
def __init__(self, json_config): def __init__(self, json_config):
with open(json_config) as f: with open(json_config) as f:
@@ -42,10 +49,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224) self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval() self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -65,6 +73,7 @@ class ClipVisionModel():
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs return outputs
def convert_to_transformers(sd, prefix): def convert_to_transformers(sd, prefix):
@@ -104,9 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else: else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else: else:
return None return None

View File

@@ -0,0 +1,19 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}

View File

@@ -1,6 +1,6 @@
import torch import torch
from typing import Callable, Protocol, TypedDict, Optional, List from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol): class UnetApplyFunction(Protocol):
@@ -42,4 +42,5 @@ __all__ = [
InputTypeDict.__name__, InputTypeDict.__name__,
ComfyNodeABC.__name__, ComfyNodeABC.__name__,
CheckLazyMixin.__name__, CheckLazyMixin.__name__,
FileLocator.__name__,
] ]

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal, TypedDict from typing import Literal, TypedDict
from typing_extensions import NotRequired
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
@@ -26,6 +27,7 @@ class IO(StrEnum):
BOOLEAN = "BOOLEAN" BOOLEAN = "BOOLEAN"
INT = "INT" INT = "INT"
FLOAT = "FLOAT" FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING" CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER" SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS" SIGMAS = "SIGMAS"
@@ -66,6 +68,7 @@ class IO(StrEnum):
b = frozenset(value.split(",")) b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b)) return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict): class RemoteInputOptions(TypedDict):
route: str route: str
"""The route to the remote source.""" """The route to the remote source."""
@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
refresh: int refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict): class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function. """Provides type hinting for the return type of the INPUT_TYPES node function.
@@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict):
# default: bool # default: bool
label_on: str label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)""" """The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_on: str label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)""" """The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions): # class InputTypeString(InputTypeOptions):
# default: str # default: str
@@ -133,7 +144,22 @@ class InputTypeOptions(TypedDict):
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
""" """
remote: RemoteInputOptions remote: RemoteInputOptions
"""Specifies the configuration for a remote input.""" """Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict): class HiddenInputTypeDict(TypedDict):
@@ -293,3 +319,14 @@ class CheckLazyMixin:
need = [name for name in kwargs if kwargs[name] is None] need = [name for name in kwargs if kwargs[name] is None]
return need return need
class FileLocator(TypedDict):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""

View File

@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import torch import torch
from enum import Enum from enum import Enum
import math import math
import os import os
import logging import logging
import copy
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
@@ -36,7 +37,7 @@ import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.hooks import HookGroup from comfy.hooks import HookGroup
@@ -63,6 +64,18 @@ class StrengthType(Enum):
CONSTANT = 1 CONSTANT = 1
LINEAR_UP = 2 LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase: class ControlBase:
def __init__(self): def __init__(self):
self.cond_hint_original = None self.cond_hint_original = None
@@ -76,7 +89,7 @@ class ControlBase:
self.compression_ratio = 8 self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact' self.upscale_algorithm = 'nearest-exact'
self.extra_args = {} self.extra_args = {}
self.previous_controlnet = None self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
self.concat_mask = False self.concat_mask = False
@@ -84,6 +97,7 @@ class ControlBase:
self.extra_concat = None self.extra_concat = None
self.extra_hooks: HookGroup = None self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@@ -110,17 +124,38 @@ class ControlBase:
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
self.previous_controlnet.cleanup() self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None self.cond_hint = None
self.extra_concat = None self.extra_concat = None
self.timestep_range = None self.timestep_range = None
def get_models(self): def get_models(self):
out = [] out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models() out += self.previous_controlnet.get_models()
return out return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self): def get_extra_hooks(self):
out = [] out = []
if self.extra_hooks is not None: if self.extra_hooks is not None:
@@ -129,7 +164,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks() out += self.previous_controlnet.get_extra_hooks()
return out return out
def copy_to(self, c): def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
@@ -280,6 +315,14 @@ class ControlNet(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self): def get_models(self):
out = super().get_models() out = super().get_models()
out.append(self.control_model_wrapped) out.append(self.control_model_wrapped)
@@ -804,6 +847,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8 compression_ratio = 8
upscale_algorithm = 'nearest-exact' upscale_algorithm = 'nearest-exact'

View File

@@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp() sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg() t_fn = lambda sigma: sigma.log().neg()
@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if solver_type not in {'heun', 'midpoint'}: if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'') raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
old_denoised = None old_denoised = None
@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None denoised_1, denoised_2 = None, None
@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
x = x + d_bar * dt x = x + d_bar * dt
old_d = d old_d = d
return x return x
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
old_denoised = None
old_denoised_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x

View File

@@ -19,6 +19,10 @@
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
import comfy.ops
ops = comfy.ops.disable_weight_init
class vector_quantize(Function): class vector_quantize(Function):
@staticmethod @staticmethod
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential( self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1), nn.ReplicationPad2d(1),
nn.Conv2d(c, c, kernel_size=3, groups=c) ops.Conv2d(c, c, kernel_size=3, groups=c)
) )
# channelwise # channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential( self.channelwise = nn.Sequential(
nn.Linear(c, c_hidden), ops.Linear(c, c_hidden),
nn.GELU(), nn.GELU(),
nn.Linear(c_hidden, c), ops.Linear(c_hidden, c),
) )
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@@ -171,16 +175,16 @@ class StageA(nn.Module):
# Encoder blocks # Encoder blocks
self.in_block = nn.Sequential( self.in_block = nn.Sequential(
nn.PixelUnshuffle(2), nn.PixelUnshuffle(2),
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
) )
down_blocks = [] down_blocks = []
for i in range(levels): for i in range(levels):
if i > 0: if i > 0:
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4) block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block) down_blocks.append(block)
down_blocks.append(nn.Sequential( down_blocks.append(nn.Sequential(
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
)) ))
self.down_blocks = nn.Sequential(*down_blocks) self.down_blocks = nn.Sequential(*down_blocks)
@@ -191,7 +195,7 @@ class StageA(nn.Module):
# Decoder blocks # Decoder blocks
up_blocks = [nn.Sequential( up_blocks = [nn.Sequential(
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)] )]
for i in range(levels): for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1): for j in range(bottleneck_blocks if i == 0 else 1):
@@ -199,11 +203,11 @@ class StageA(nn.Module):
up_blocks.append(block) up_blocks.append(block)
if i < levels - 1: if i < levels - 1:
up_blocks.append( up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1)) padding=1))
self.up_blocks = nn.Sequential(*up_blocks) self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential( self.out_block = nn.Sequential(
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2), nn.PixelShuffle(2),
) )
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
super().__init__() super().__init__()
d = max(depth - 3, 3) d = max(depth - 3, 3)
layers = [ layers = [
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
] ]
for i in range(depth - 1): for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0)) c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0)) c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out)) layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2)) layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers) self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid() self.logits = nn.Sigmoid()
def forward(self, x, cond=None): def forward(self, x, cond=None):

View File

@@ -19,6 +19,9 @@ import torch
import torchvision import torchvision
from torch import nn from torch import nn
import comfy.ops
ops = comfy.ops.disable_weight_init
# EfficientNet # EfficientNet
class EfficientNetEncoder(nn.Module): class EfficientNetEncoder(nn.Module):
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
super().__init__() super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval() self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential( self.mapper = nn.Sequential(
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
) )
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x): def forward(self, x):
x = x * 0.5 + 0.5 x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
o = self.mapper(self.backbone(x)) o = self.mapper(self.backbone(x))
return o return o
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3): def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__() super().__init__()
self.blocks = nn.Sequential( self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden), nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden), nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 2), nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 2), nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 4), nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 4), nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 4), nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(), nn.GELU(),
nn.BatchNorm2d(c_hidden // 4), nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
) )
def forward(self, x): def forward(self, x):

View File

@@ -105,7 +105,9 @@ class Modulation(nn.Module):
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple: def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) if vec.ndim == 2:
vec = vec[:, None, :]
out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
return ( return (
ModulationOut(*out[:3]), ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
) )
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
else:
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
return tensor
class DoubleStreamBlock(nn.Module): class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
) )
self.flipped_img_txt = flipped_img_txt self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention # prepare image for attention
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated) img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention # prepare txt for attention
txt_modulated = self.txt_norm1(txt) txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated) txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks # calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks # calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16: if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask) attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16: if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x return x
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor: def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) if vec.ndim == 2:
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] vec = vec[:, None, :]
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
x = self.linear(x) x = self.linear(x)
return x return x

View File

@@ -10,8 +10,9 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape q_shape = q.shape
k_shape = k.shape k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2) if pe is not None:
k = k.float().reshape(*k.shape[:-1], -1, 1, 2) q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -115,8 +115,11 @@ class Flux(nn.Module):
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt) txt = self.txt_in(txt)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):

View File

@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor, timesteps: Tensor,
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
guiding_frame_index=None,
control=None, control=None,
transformer_options={}, transformer_options={},
) -> Tensor: ) -> Tensor:
@@ -237,7 +238,17 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img) img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
if self.params.guidance_embed: if self.params.guidance_embed:
if guidance is not None: if guidance is not None:
@@ -264,14 +275,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
else: else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module):
img = img[:, : img_len] img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:] shape = initial_shape[-3:]
for i in range(len(shape)): for i in range(len(shape)):
@@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
patch_size = self.patch_size patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
return out return out

View File

@@ -7,7 +7,7 @@ from einops import rearrange
import math import math
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def get_timestep_embedding( def get_timestep_embedding(
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0, positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048], positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs): dtype=None, device=None, operations=None, **kwargs):
super().__init__() super().__init__()
self.generator = None self.generator = None
self.vae_scale_factors = vae_scale_factors
self.dtype = dtype self.dtype = dtype
self.out_channels = in_channels self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1) self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs): def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape) orig_shape = list(x.shape)
x = self.patchifier.patchify(x) x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
scale_factors=self.vae_scale_factors,
causal_fix=self.causal_temporal_positioning,
)
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x) x = self.patchify_proj(x)
timestep = timestep * 1000.0 timestep = timestep * 1000.0
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
if attention_mask is not None and not torch.is_floating_point(attention_mask): if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0] batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single( timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
) )
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x return x

View File

@@ -6,16 +6,29 @@ from einops import rearrange
from torch import Tensor from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: def latent_to_pixel_coords(
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.""" latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
dims_to_append = target_dims - x.ndim ) -> Tensor:
if dims_to_append < 0: """
raise ValueError( Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" configuration.
Args:
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
containing the latent corner coordinates of each token.
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
causal_fix (bool): Whether to take into account the different temporal scale
of the first frame. Default = False for backwards compatibility.
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
) )
elif dims_to_append == 0: if causal_fix:
return x # Fix temporal scale for first frame to 1 due to causality
return x[(...,) + (None,) * dims_to_append] pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC): class Patchifier(ABC):
@@ -44,29 +57,26 @@ class Patchifier(ABC):
def patch_size(self): def patch_size(self):
return self._patch_size return self._patch_size
def get_grid( def get_latent_coords(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device self, latent_num_frames, latent_height, latent_width, batch_size, device
): ):
f = orig_num_frames // self._patch_size[0] """
h = orig_height // self._patch_size[1] Return a tensor of shape [batch_size, 3, num_patches] containing the
w = orig_width // self._patch_size[2] top-left corner latent coordinates of each latent patch.
grid_h = torch.arange(h, dtype=torch.float32, device=device) The tensor is repeated for each batch element.
grid_w = torch.arange(w, dtype=torch.float32, device=device) """
grid_f = torch.arange(f, dtype=torch.float32, device=device) latent_sample_coords = torch.meshgrid(
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij') torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
grid = torch.stack(grid, dim=0) torch.arange(0, latent_height, self._patch_size[1], device=device),
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
if scale_grid is not None: )
for i in range(3): latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
if isinstance(scale_grid[i], Tensor): latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
scale = append_dims(scale_grid[i], grid.ndim - 1) latent_coords = rearrange(
else: latent_coords, "b c f h w -> b c (f h w)", b=batch_size
scale = scale_grid[i] )
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i] return latent_coords
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier): class SymmetricPatchifier(Patchifier):
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
self, self,
latents: Tensor, latents: Tensor,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
b, _, f, h, w = latents.shape
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange( latents = rearrange(
latents, latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1], p2=self._patch_size[1],
p3=self._patch_size[2], p3=self._patch_size[2],
) )
return latents return latents, latent_coords
def unpatchify( def unpatchify(
self, self,

View File

@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1, stride: Union[int, Tuple[int]] = 1,
dilation: int = 1, dilation: int = 1,
groups: int = 1, groups: int = 1,
spatial_padding_mode: str = "zeros",
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
stride=stride, stride=stride,
dilation=dilation, dilation=dilation,
padding=padding, padding=padding,
padding_mode="zeros", padding_mode=spatial_padding_mode,
groups=groups, groups=groups,
) )

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
import torch import torch
from torch import nn from torch import nn
from functools import partial from functools import partial
import math import math
from einops import rearrange from einops import rearrange
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class Encoder(nn.Module): class Encoder(nn.Module):
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`): norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`. The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`): latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
""" """
def __init__( def __init__(
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3, dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
blocks=[("res_x", 1)], blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128, base_channels: int = 128,
norm_num_groups: int = 32, norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1, patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel", latent_log_var: str = "per_channel",
spatial_padding_mode: str = "zeros",
): ):
super().__init__() super().__init__()
self.patch_size = patch_size self.patch_size = patch_size
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
stride=1, stride=1,
padding=1, padding=1,
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
self.down_blocks = nn.ModuleList([]) self.down_blocks = nn.ModuleList([])
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6, resnet_eps=1e-6,
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "res_x_y": elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel output_channel = block_params.get("multiplier", 2) * output_channel
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
eps=1e-6, eps=1e-6,
groups=norm_num_groups, groups=norm_num_groups,
norm_layer=norm_layer, norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_time": elif block_name == "compress_time":
block = make_conv_nd( block = make_conv_nd(
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
kernel_size=3, kernel_size=3,
stride=(2, 1, 1), stride=(2, 1, 1),
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_space": elif block_name == "compress_space":
block = make_conv_nd( block = make_conv_nd(
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
kernel_size=3, kernel_size=3,
stride=(1, 2, 2), stride=(1, 2, 2),
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_all": elif block_name == "compress_all":
block = make_conv_nd( block = make_conv_nd(
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
kernel_size=3, kernel_size=3,
stride=(2, 2, 2), stride=(2, 2, 2),
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_all_x_y": elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel output_channel = block_params.get("multiplier", 2) * output_channel
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
kernel_size=3, kernel_size=3,
stride=(2, 2, 2), stride=(2, 2, 2),
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
) )
else: else:
raise ValueError(f"unknown block: {block_name}") raise ValueError(f"unknown block: {block_name}")
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
conv_out_channels *= 2 conv_out_channels *= 2
elif latent_log_var == "uniform": elif latent_log_var == "uniform":
conv_out_channels += 1 conv_out_channels += 1
elif latent_log_var == "constant":
conv_out_channels += 1
elif latent_log_var != "none": elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}") raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd( self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True dims,
output_channel,
conv_out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1) sample = torch.cat([sample, repeated_last_channel], dim=1)
else: else:
raise ValueError(f"Invalid input shape: {sample.shape}") raise ValueError(f"Invalid input shape: {sample.shape}")
elif self.latent_log_var == "constant":
sample = sample[:, :-1, ...]
approx_ln_0 = (
-30
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
sample = torch.cat(
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
dim=1,
)
return sample return sample
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
dims, dims,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
blocks=[("res_x", 1)], blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128, base_channels: int = 128,
layers_per_block: int = 2, layers_per_block: int = 2,
norm_num_groups: int = 32, norm_num_groups: int = 32,
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm", norm_layer: str = "group_norm",
causal: bool = True, causal: bool = True,
timestep_conditioning: bool = False, timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
): ):
super().__init__() super().__init__()
self.patch_size = patch_size self.patch_size = patch_size
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
stride=1, stride=1,
padding=1, padding=1,
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
self.up_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([])
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False), inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning, timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "attn_res_x": elif block_name == "attn_res_x":
block = UNetMidBlock3D( block = UNetMidBlock3D(
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False), inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning, timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"], attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "res_x_y": elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2) output_channel = output_channel // block_params.get("multiplier", 2)
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False), inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False, timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_time": elif block_name == "compress_time":
block = DepthToSpaceUpsample( block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1) dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_space": elif block_name == "compress_space":
block = DepthToSpaceUpsample( block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2) dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "compress_all": elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1) output_channel = output_channel // block_params.get("multiplier", 1)
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
stride=(2, 2, 2), stride=(2, 2, 2),
residual=block_params.get("residual", False), residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1), out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
) )
else: else:
raise ValueError(f"unknown layer: {block_name}") raise ValueError(f"unknown layer: {block_name}")
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd( self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32): resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks. The number of groups to use in the group normalization layers of the resnet blocks.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
inject_noise (`bool`, *optional*, defaults to `False`):
Whether to inject noise into the hidden states.
timestep_conditioning (`bool`, *optional*, defaults to `False`):
Whether to condition the hidden states on the timestep.
Returns: Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm", norm_layer: str = "group_norm",
inject_noise: bool = False, inject_noise: bool = False,
timestep_conditioning: bool = False, timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
): ):
super().__init__() super().__init__()
resnet_groups = ( resnet_groups = (
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer, norm_layer=norm_layer,
inject_noise=inject_noise, inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning, timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
) )
for _ in range(num_layers) for _ in range(num_layers)
] ]
) )
def forward( def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None self,
hidden_states: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
timestep_embed = None timestep_embed = None
if self.timestep_conditioning: if self.timestep_conditioning:
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
return hidden_states return hidden_states
class SpaceToDepthDownsample(nn.Module):
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
super().__init__()
self.stride = stride
self.group_size = in_channels * math.prod(stride) // out_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=out_channels // math.prod(stride),
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
# skip connection
x_in = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
x_in = x_in.mean(dim=2)
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x = x + x_in
return x
class DepthToSpaceUpsample(nn.Module): class DepthToSpaceUpsample(nn.Module):
def __init__( def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1 self,
dims,
in_channels,
stride,
residual=False,
out_channels_reduction_factor=1,
spatial_padding_mode="zeros",
): ):
super().__init__() super().__init__()
self.stride = stride self.stride = stride
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3, kernel_size=3,
stride=1, stride=1,
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
self.residual = residual self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -558,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module):
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None: def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__() super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x): def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c") x = rearrange(x, "b c d h w -> b d h w c")
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm", norm_layer: str = "group_norm",
inject_noise: bool = False, inject_noise: bool = False,
timestep_conditioning: bool = False, timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
stride=1, stride=1,
padding=1, padding=1,
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
if inject_noise: if inject_noise:
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
stride=1, stride=1,
padding=1, padding=1,
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode,
) )
if inject_noise: if inject_noise:
@@ -801,9 +941,44 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module): class VideoVAE(nn.Module):
def __init__(self, version=0): def __init__(self, version=0, config=None):
super().__init__() super().__init__()
if config is None:
config = self.guess_config(version)
self.timestep_conditioning = config.get("timestep_conditioning", False)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()
def guess_config(self, version):
if version == 0: if version == 0:
config = { config = {
"_class_name": "CausalVideoAutoencoder", "_class_name": "CausalVideoAutoencoder",
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False, "use_quant_conv": False,
"causal_decoder": False, "causal_decoder": False,
} }
else: elif version == 1:
config = { config = {
"_class_name": "CausalVideoAutoencoder", "_class_name": "CausalVideoAutoencoder",
"dims": 3, "dims": 3,
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
"causal_decoder": False, "causal_decoder": False,
"timestep_conditioning": True, "timestep_conditioning": True,
} }
else:
double_z = config.get("double_z", True) config = {
latent_log_var = config.get( "_class_name": "CausalVideoAutoencoder",
"latent_log_var", "per_channel" if double_z else "none" "dims": 3,
) "in_channels": 3,
"out_channels": 3,
self.encoder = Encoder( "latent_channels": 128,
dims=config["dims"], "encoder_blocks": [
in_channels=config.get("in_channels", 3), ["res_x", {"num_layers": 4}],
out_channels=config["latent_channels"], ["compress_space_res", {"multiplier": 2}],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), ["res_x", {"num_layers": 6}],
patch_size=config.get("patch_size", 1), ["compress_time_res", {"multiplier": 2}],
latent_log_var=latent_log_var, ["res_x", {"num_layers": 6}],
norm_layer=config.get("norm_layer", "group_norm"), ["compress_all_res", {"multiplier": 2}],
) ["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
self.decoder = Decoder( ["res_x", {"num_layers": 2}]
dims=config["dims"], ],
in_channels=config["latent_channels"], "decoder_blocks": [
out_channels=config.get("out_channels", 3), ["res_x", {"num_layers": 5, "inject_noise": False}],
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), ["compress_all", {"residual": True, "multiplier": 2}],
patch_size=config.get("patch_size", 1), ["res_x", {"num_layers": 5, "inject_noise": False}],
norm_layer=config.get("norm_layer", "group_norm"), ["compress_all", {"residual": True, "multiplier": 2}],
causal=config.get("causal_decoder", False), ["res_x", {"num_layers": 5, "inject_noise": False}],
timestep_conditioning=config.get("timestep_conditioning", False), ["compress_all", {"residual": True, "multiplier": 2}],
) ["res_x", {"num_layers": 5, "inject_noise": False}]
],
self.timestep_conditioning = config.get("timestep_conditioning", False) "scaling_factor": 1.0,
self.per_channel_statistics = processor() "norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True
}
return config
def encode(self, x): def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1) means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means) return self.per_channel_statistics.normalize(means)

View File

@@ -17,7 +17,11 @@ def make_conv_nd(
groups=1, groups=1,
bias=True, bias=True,
causal=False, causal=False,
spatial_padding_mode="zeros",
temporal_padding_mode="zeros",
): ):
if not (spatial_padding_mode == temporal_padding_mode or causal):
raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2: if dims == 2:
return ops.Conv2d( return ops.Conv2d(
in_channels=in_channels, in_channels=in_channels,
@@ -28,6 +32,7 @@ def make_conv_nd(
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias, bias=bias,
padding_mode=spatial_padding_mode,
) )
elif dims == 3: elif dims == 3:
if causal: if causal:
@@ -40,6 +45,7 @@ def make_conv_nd(
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias, bias=bias,
spatial_padding_mode=spatial_padding_mode,
) )
return ops.Conv3d( return ops.Conv3d(
in_channels=in_channels, in_channels=in_channels,
@@ -50,6 +56,7 @@ def make_conv_nd(
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias, bias=bias,
padding_mode=spatial_padding_mode,
) )
elif dims == (2, 1): elif dims == (2, 1):
return DualConv3d( return DualConv3d(
@@ -59,6 +66,7 @@ def make_conv_nd(
stride=stride, stride=stride,
padding=padding, padding=padding,
bias=bias, bias=bias,
padding_mode=spatial_padding_mode,
) )
else: else:
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1, groups=1,
bias=True, bias=True,
padding_mode="zeros",
): ):
super(DualConv3d, self).__init__() super(DualConv3d, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3 # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size) kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
self.padding1, self.padding1,
self.dilation1, self.dilation1,
self.groups, self.groups,
padding_mode=self.padding_mode,
) )
if skip_time_conv: if skip_time_conv:
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
self.padding2, self.padding2,
self.dilation2, self.dilation2,
self.groups, self.groups,
padding_mode=self.padding_mode,
) )
return x return x
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2]) stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2]) padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2]) dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) x = F.conv2d(
x,
weight1,
self.bias1,
stride1,
padding1,
dilation1,
self.groups,
padding_mode=self.padding_mode,
)
_, _, h, w = x.shape _, _, h, w = x.shape
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0] stride2 = self.stride2[0]
padding2 = self.padding2[0] padding2 = self.padding2[0]
dilation2 = self.dilation2[0] dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) x = F.conv1d(
x,
weight2,
self.bias2,
stride2,
padding2,
dilation2,
self.groups,
padding_mode=self.padding_mode,
)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x return x

View File

@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1) exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from comfy.cli_args import args from comfy.cli_args import args
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
return out return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
assert mask is None
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic optimized_attention = attention_basic
if model_management.sage_attention_enabled(): if model_management.sage_attention_enabled():
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled(): elif model_management.xformers_enabled():
logging.info("Using xformers attention") logging.info("Using xformers attention")
optimized_attention = attention_xformers optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention") logging.info("Using pytorch attention")
optimized_attention = attention_pytorch optimized_attention = attention_pytorch

View File

@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
context, context,
clip_fea=None, clip_fea=None,
freqs=None, freqs=None,
transformer_options={},
): ):
r""" r"""
Forward pass through the diffusion model Forward pass through the diffusion model
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1) context = torch.concat([context_clip, context], dim=1)
# arguments patches_replace = transformer_options.get("patches_replace", {})
kwargs = dict( blocks_replace = patches_replace.get("dit", {})
e=e0, for i, block in enumerate(self.blocks):
freqs=freqs, if ("double_block", i) in blocks_replace:
context=context) def block_wrap(args):
out = {}
for block in self.blocks: out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
x = block(x, **kwargs) return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
# head # head
x = self.head(x, e) x = self.head(x, e)
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x return x
def forward(self, x, timestep, context, clip_fea=None, **kwargs): def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size patch_size = self.patch_size
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2) freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w] return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes): def unpatchify(self, x, grid_sizes):
r""" r"""

View File

@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None) fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
@@ -161,9 +161,13 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype) extra = extra.to(dtype)
extra_conds[o] = extra extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x) return self.model_sampling.calculate_denoised(sigma, model_output, x)
def process_timestep(self, timestep, **kwargs):
return timestep
def get_dtype(self): def get_dtype(self):
return self.diffusion_model.dtype return self.diffusion_model.dtype
@@ -185,6 +189,11 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]: if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5:
if concat_latent_image.shape[-3] < noise.shape[-3]:
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
else:
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
@@ -213,6 +222,11 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted": elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1]) cond_concat.append(torch.zeros_like(noise)[:, :1])
if ck == "concat_image":
if concat_latent_image is not None:
cond_concat.append(concat_latent_image.to(device))
else:
cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
return data return data
return None return None
@@ -845,17 +859,26 @@ class LTXV(BaseModel):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel): class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -872,20 +895,35 @@ class HunyuanVideo(BaseModel):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if image is not None:
padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
guidance = kwargs.get("guidance", 6.0) guidance = kwargs.get("guidance", 6.0)
if guidance is not None: if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
guiding_frame_index = kwargs.get("guiding_frame_index", None)
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
return out return out
def scale_latent_inpaint(self, latent_image, **kwargs):
return latent_image
class HunyuanVideoI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image", "mask_inverted")
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
def scale_latent_inpaint(self, latent_image, **kwargs):
return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
class CosmosVideo(BaseModel): class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
@@ -935,11 +973,11 @@ class WAN21(BaseModel):
self.image_to_video = image_to_video self.image_to_video = image_to_video
def concat_cond(self, **kwargs): def concat_cond(self, **kwargs):
if not self.image_to_video: noise = kwargs.get("noise", None)
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
return None return None
image = kwargs.get("concat_latent_image", None) image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if image is None: if image is None:
@@ -949,6 +987,9 @@ class WAN21(BaseModel):
image = self.process_latent_in(image) image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0]) image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video:
return image
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None: if mask is None:
mask = torch.zeros_like(noise)[:, :4] mask = torch.zeros_like(noise)[:, :4]

View File

@@ -1,3 +1,4 @@
import json
import comfy.supported_models import comfy.supported_models
import comfy.supported_models_base import comfy.supported_models_base
import comfy.utils import comfy.utils
@@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None return None
def detect_unet_config(state_dict, key_prefix): def detect_unet_config(state_dict, key_prefix, metadata=None):
state_dict_keys = list(state_dict.keys()) state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@@ -210,6 +211,8 @@ def detect_unet_config(state_dict, key_prefix):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {} dit_config = {}
dit_config["image_model"] = "ltxv" dit_config["image_model"] = "ltxv"
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@@ -454,8 +457,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
logging.error("no match {}".format(unet_config)) logging.error("no match {}".format(unet_config))
return None return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
unet_config = detect_unet_config(state_dict, unet_key_prefix) unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None: if unet_config is None:
return None return None
model_config = model_config_from_unet_config(unet_config, state_dict) model_config = model_config_from_unet_config(unet_config, state_dict)
@@ -468,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
model_config.scaled_fp8 = scaled_fp8_weight.dtype model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32: if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
return model_config return model_config

View File

@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import psutil import psutil
import logging import logging
@@ -26,6 +27,10 @@ import platform
import weakref import weakref
import gc import gc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -145,6 +150,25 @@ def get_torch_device():
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False): def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled global directml_enabled
if dev is None: if dev is None:
@@ -186,12 +210,21 @@ def get_total_memory(dev=None, torch_total_too=False):
else: else:
return mem_total return mem_total
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try: try:
logging.info("pytorch version: {}".format(torch_version)) logging.info("pytorch version: {}".format(torch_version))
mac_ver = mac_version()
if mac_ver is not None:
logging.info("Mac Version {}".format(mac_ver))
except: except:
pass pass
@@ -347,9 +380,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except: except:
logging.warning("Could not pick default device.") logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models: list[LoadedModel] = []
current_loaded_models = []
def module_size(module): def module_size(module):
module_mem = 0 module_mem = 0
@@ -360,7 +397,7 @@ def module_size(module):
return module_mem return module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model: ModelPatcher):
self._set_model(model) self._set_model(model)
self.device = model.load_device self.device = model.load_device
self.real_model = None self.real_model = None
@@ -368,7 +405,7 @@ class LoadedModel:
self.model_finalizer = None self.model_finalizer = None
self._patcher_finalizer = None self._patcher_finalizer = None
def _set_model(self, model): def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model) self._model = weakref.ref(model)
if model.parent is not None: if model.parent is not None:
self._parent_model = weakref.ref(model.parent) self._parent_model = weakref.ref(model.parent)
@@ -581,7 +618,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory() loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
@@ -921,6 +958,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled(): def sage_attention_enabled():
return args.use_sage_attention return args.use_sage_attention
def flash_attention_enabled():
return args.use_flash_attention
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_enabled
global cpu_state global cpu_state
@@ -969,12 +1009,6 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype(): def force_upcast_attention_dtype():
upcast = args.force_upcast_attention upcast = args.force_upcast_attention
@@ -1213,6 +1247,31 @@ def soft_empty_cache(force=False):
def unload_all_models(): def unload_all_models():
free_memory(1e30, get_torch_device()) free_memory(1e30, get_torch_device())
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else
import threading import threading

View File

@@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict): def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options) return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches): def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {} new_hook_patches = {}
for hook_ref in orig_hook_patches: for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {} new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]: for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches return new_hook_patches
def wipe_lowvram_weight(m): def wipe_lowvram_weight(m):
@@ -240,6 +243,9 @@ class ModelPatcher:
self.is_clip = False self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'): if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
@@ -318,16 +324,90 @@ class ModelPatcher:
n.is_clip = self.is_clip n.is_clip = self.is_clip
n.hook_mode = self.hook_mode n.hook_mode = self.hook_mode
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n) callback(self, n)
return n return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
n.model = copy.deepcopy(n.model)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other): def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model: if hasattr(other, 'model') and self.model is other.model:
return True return True
return False return False
def clone_has_same_weights(self, clone: 'ModelPatcher'): def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone): if not self.is_clone(clone):
return False return False
@@ -747,6 +827,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0): def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected(): with self.use_ejected():
hooks_unpatched = False
memory_freed = 0 memory_freed = 0
patch_counter = 0 patch_counter = 0
unload_list = self._load_list() unload_list = self._load_list()
@@ -770,6 +851,10 @@ class ModelPatcher:
move_weight = False move_weight = False
break break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update: if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight) comfy.utils.copy_to_param(self.model, key, bk.weight)
else: else:
@@ -924,7 +1009,7 @@ class ModelPatcher:
return self.additional_models.get(key, []) return self.additional_models.get(key, [])
def get_additional_models(self): def get_additional_models(self):
all_models = [] all_models: list[ModelPatcher] = []
for models in self.additional_models.values(): for models in self.additional_models.values():
all_models.extend(models) all_models.extend(models)
return all_models return all_models
@@ -978,9 +1063,13 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self) callback(self)
def prepare_state(self, timestep): def prepare_state(self, timestep, model_options, ignore_multigpu=False):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep) callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
def restore_hook_patches(self): def restore_hook_patches(self):
if self.hook_patches_backup is not None: if self.hook_patches_backup is not None:
@@ -993,12 +1082,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0] curr_t = t[0]
reset_current_hooks = False reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {}) transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks: for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling # this will cause the weights to be recalculated when sampling
if changed: if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed # reset current_hooks if contains hook that changed
if self.current_hooks is not None: if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks: for current_hook in self.current_hooks.hooks:
@@ -1010,6 +1105,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group) self.cached_hook_patches.pop(cached_group)
if reset_current_hooks: if reset_current_hooks:
self.patch_hooks(None) self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None): registered: comfy.hooks.HookGroup = None):
@@ -1089,7 +1206,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup): def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected(): with self.use_ejected():
self.unpatch_hooks()
if hooks is not None: if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys()) model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None memory_counter = None
@@ -1100,12 +1216,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it # if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None) cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None: if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights: for key in cached_weights:
if key not in model_sd_keys: if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else: else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks) relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None original_weights = None
if len(relevant_patches) > 0: if len(relevant_patches) > 0:
@@ -1116,6 +1236,8 @@ class ModelPatcher:
continue continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter) memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1172,12 +1294,18 @@ class ModelPatcher:
del out_weight del out_weight
del weight del weight
def unpatch_hooks(self) -> None: def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected(): with self.use_ejected():
if len(self.hook_backup) == 0: if len(self.hook_backup) == 0:
self.current_hooks = None self.current_hooks = None
return return
keys = list(self.hook_backup.keys()) keys = list(self.hook_backup.keys())
if whitelist_keys_set:
for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys: for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

176
comfy/multigpu.py Normal file
View File

@@ -0,0 +1,176 @@
from __future__ import annotations
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def get_torch_device_list():
devices = ["default"]
for device in comfy.model_management.get_all_torch_devices():
device: torch.device
devices.append(str(device.index))
return devices
def get_device_from_str(device_str: str, throw_error_if_not_found=False):
if device_str == "default":
return comfy.model_management.get_torch_device()
for device in comfy.model_management.get_all_torch_devices():
device: torch.device
if str(device.index) == device_str:
return device
if throw_error_if_not_found:
raise Exception(f"Device with index '{device_str}' not found.")
logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
extra_devices = extra_devices[:max_gpus-1]
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@@ -17,6 +17,7 @@
""" """
import torch import torch
import logging
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature from comfy.cli_args import args, PerformanceFeature
import comfy.float import comfy.float
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast): class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear): class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None: if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if ( if (
fp8_compute and fp8_compute and

View File

@@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP: class CallbacksMP:
ON_CLONE = "on_clone" ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after" ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after" ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup" ON_CLEANUP = "on_cleanup"

View File

@@ -1,7 +1,9 @@
from __future__ import annotations from __future__ import annotations
import torch
import uuid import uuid
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.model_patcher
import comfy.utils import comfy.utils
import comfy.hooks import comfy.hooks
import comfy.patcher_extension import comfy.patcher_extension
@@ -104,16 +106,57 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'): if hasattr(m, 'cleanup'):
m.cleanup() m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options) models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update? models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model real_model: BaseModel = model.model
return real_model, conds, models return real_model, conds, models
@@ -126,7 +169,7 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
''' '''
Registers hooks from conds. Registers hooks from conds.
''' '''
@@ -159,3 +202,18 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False) copy_dict1=False)
return to_load_options return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,4 +1,6 @@
from __future__ import annotations from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple from typing import TYPE_CHECKING, Callable, NamedTuple
@@ -18,6 +20,13 @@ import comfy.patcher_extension
import comfy.hooks import comfy.hooks
import scipy.stats import scipy.stats
import numpy import numpy
import threading
def add_area_dims(area, num_dims):
while (len(area) // 2) < num_dims:
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
return area
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:]) dims = tuple(x_in.shape[2:])
@@ -34,8 +43,9 @@ def get_area_and_mult(conds, x_in, timestep_in):
return None return None
if 'area' in conds: if 'area' in conds:
area = list(conds['area']) area = list(conds['area'])
while (len(area) // 2) < len(dims): area = add_area_dims(area, len(dims))
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:] if (len(area) // 2) > len(dims):
area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
if 'strength' in conds: if 'strength' in conds:
strength = conds['strength'] strength = conds['strength']
@@ -53,7 +63,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds: if "mask_strength" in conds:
mask_strength = conds["mask_strength"] mask_strength = conds["mask_strength"]
mask = conds['mask'] mask = conds['mask']
assert(mask.shape[1:] == x_in.shape[2:]) assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]] mask = mask[:input_x.shape[0]]
if area is not None: if area is not None:
@@ -67,16 +77,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
mult = mask * strength mult = mask * strength
if 'mask' not in conds and area is not None: if 'mask' not in conds and area is not None:
rr = 8 fuzz = 8
for i in range(len(dims)): for i in range(len(dims)):
rr = min(fuzz, mult.shape[2 + i] // 4)
if area[len(dims) + i] != 0: if area[len(dims) + i] != 0:
for t in range(rr): for t in range(rr):
m = mult.narrow(i + 2, t, 1) m = mult.narrow(i + 2, t, 1)
m *= ((1.0/rr) * (t + 1)) m *= ((1.0 / rr) * (t + 1))
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]: if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
for t in range(rr): for t in range(rr):
m = mult.narrow(i + 2, area[i] - 1 - t, 1) m = mult.narrow(i + 2, area[i] - 1 - t, 1)
m *= ((1.0/rr) * (t + 1)) m *= ((1.0 / rr) * (t + 1))
conditioning = {} conditioning = {}
model_conds = conds["model_conds"] model_conds = conds["model_conds"]
@@ -132,7 +143,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning) return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list): def cond_cat(c_list, device=None):
temp = {} temp = {}
for x in c_list: for x in c_list:
for k in x: for k in x:
@@ -144,6 +155,8 @@ def cond_cat(c_list):
for k in temp: for k in temp:
conds = temp[k] conds = temp[k]
out[k] = conds[0].concat(conds[1:]) out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out return out
@@ -197,7 +210,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
) )
return executor.execute(model, conds, x_in, timestep, model_options) return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = [] out_conds = []
out_counts = [] out_counts = []
# separate conds by matching hooks # separate conds by matching hooks
@@ -229,7 +244,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
if has_default_conds: if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep) model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately # run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items(): for hooks, to_run in hooked_to_run.items():
@@ -331,6 +346,190 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
return out_conds return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
for thread in threads:
thread.join()
for output, mult, area, batch_chunks, cond_or_uncond in results:
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -551,24 +750,36 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.") logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device) return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2 def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c: if 'area' not in c:
return return
def area_inside(a, area_cmp):
a = add_area_dims(a, len(area_cmp) // 2)
area_cmp = add_area_dims(area_cmp, len(a) // 2)
a_l = len(a) // 2
area_cmp_l = len(area_cmp) // 2
for i in range(min(a_l, area_cmp_l)):
if a[a_l + i] < area_cmp[area_cmp_l + i]:
return False
for i in range(min(a_l, area_cmp_l)):
if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
return False
return True
c_area = c['area'] c_area = c['area']
smallest = None smallest = None
for x in conds: for x in conds:
if 'area' in x: if 'area' in x:
a = x['area'] a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]: if area_inside(c_area, a):
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None: if smallest is None:
smallest = x smallest = x
elif 'area' not in smallest: elif 'area' not in smallest:
smallest = x smallest = x
else: else:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]: if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
smallest = x smallest = x
else: else:
if smallest is None: if smallest is None:
@@ -616,6 +827,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []
@@ -690,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation"] "gradient_estimation", "er_sde"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -858,7 +1071,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None) to_load_options = model_options.get("to_load_options", None)
if to_load_options is None: if to_load_options is None:
return return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = [] casts = []
if device is not None: if device is not None:
casts.append(device) casts.append(device)
@@ -867,18 +1082,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing # if nothing to apply, do nothing
if len(casts) == 0: if len(casts) == 0:
return return
# try to call .to on patches # try to call .to on patches
if "patches" in to_load_options: if "patches" in transformer_options:
patches = to_load_options["patches"] patches = transformer_options["patches"]
for name in patches: for name in patches:
patch_list = patches[name] patch_list = patches[name]
for i in range(len(patch_list)): for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"): if hasattr(patch_list[i], "to"):
for cast in casts: for cast in casts:
patch_list[i] = patch_list[i].to(cast) patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options: if "patches_replace" in transformer_options:
patches = to_load_options["patches_replace"] patches = transformer_options["patches_replace"]
for name in patches: for name in patches:
patch_list = patches[name] patch_list = patches[name]
for k in patch_list: for k in patch_list:
@@ -888,8 +1102,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks # try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"] wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks: for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options: if wc_name in transformer_options:
wc: dict[str, list] = to_load_options[wc_name] wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values(): for wc_dict in wc.values():
for wc_list in wc_dict.values(): for wc_list in wc_dict.values():
for i in range(len(wc_list)): for i in range(len(wc_list)):
@@ -897,7 +1111,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts: for cast in casts:
wc_list[i] = wc_list[i].to(cast) wc_list[i] = wc_list[i].to(cast)
class CFGGuider: class CFGGuider:
def __init__(self, model_patcher: ModelPatcher): def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher self.model_patcher = model_patcher
@@ -943,6 +1156,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
if denoise_mask is not None: if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
@@ -953,9 +1168,13 @@ class CFGGuider:
try: try:
self.model_patcher.pre_run() self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally: finally:
self.model_patcher.cleanup() self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model del self.inner_model

View File

@@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import json
import torch import torch
from enum import Enum from enum import Enum
import logging import logging
@@ -134,8 +135,8 @@ class CLIP:
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):
self.layer_idx = layer_idx self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False): def tokenize(self, text, return_word_ids=False, **kwargs):
return self.tokenizer.tokenize_with_weights(text, return_word_ids) return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]): def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds: if self.apply_hooks_to_conds:
@@ -249,7 +250,7 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -357,7 +358,12 @@ class VAE:
version = 0 version = 0
elif tensor_conv1.shape[0] == 1024: elif tensor_conv1.shape[0] == 1024:
version = 1 version = 1
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version) if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
vae_config = None
if metadata is not None and "config" in metadata:
vae_config = json.loads(metadata["config"]).get("vae", None)
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128 self.latent_channels = 128
self.latent_dim = 3 self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@@ -434,6 +440,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels): def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode() downscale_ratio = self.spacial_compression_encode()
@@ -489,6 +499,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in): def decode(self, samples_in):
self.throw_exception_if_invalid()
pixel_samples = None pixel_samples = None
try: try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -519,6 +530,7 @@ class VAE:
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2 dims = samples.ndim - 2
@@ -547,6 +559,7 @@ class VAE:
return output.movedim(1, -1) return output.movedim(1, -1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1) pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5: if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -579,6 +592,7 @@ class VAE:
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1) pixel_samples = pixel_samples.movedim(-1, 1)
@@ -873,13 +887,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@@ -891,9 +905,14 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None: if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes) unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None: if model_config.scaled_fp8 is not None:
@@ -920,7 +939,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_vae: if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd) vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip: if output_clip:
clip_target = model_config.clip_target(state_dict=sd) clip_target = model_config.clip_target(state_dict=sd)

View File

@@ -158,71 +158,93 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_idx = self.options_default[1] self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2] self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds): def process_tokens(self, tokens, device):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0]
embedding_weights = []
for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, numbers.Integral):
tokens_temp += [int(y)]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
else:
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
while len(tokens_temp) < len(x):
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
self.transformer.set_input_embeddings(new_embedding)
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
return processed_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", None) end_token = self.special_tokens.get("end", None)
if end_token is None: if end_token is None:
cmp_token = self.special_tokens.get("pad", -1) cmp_token = self.special_tokens.get("pad", -1)
else: else:
cmp_token = end_token cmp_token = end_token
for x in range(attention_mask.shape[0]): embeds_out = []
for y in range(attention_mask.shape[1]): attention_masks = []
attention_mask[x, y] = 1 num_tokens = []
if tokens[x, y] == cmp_token:
for x in tokens:
attention_mask = []
tokens_temp = []
other_embeds = []
eos = False
index = 0
for y in x:
if isinstance(y, numbers.Integral):
if eos:
attention_mask.append(0)
else:
attention_mask.append(1)
token = int(y)
tokens_temp += [token]
if not eos and token == cmp_token:
if end_token is None: if end_token is None:
attention_mask[x, y] = 0 attention_mask[-1] = 0
break eos = True
else:
other_embeds.append((index, y))
index += 1
tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
if emb is None:
index += -1
continue
ind = index + o[0]
emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
emb_shape = emb.shape[1]
if emb.shape[-1] == tokens_embed.shape[-1]:
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
else:
index += -1
pad_extra += emb_shape
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
if pad_extra > 0:
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
attention_mask = attention_mask + [0] * pad_extra
embeds_out.append(tokens_embed)
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
attention_mask_model = None attention_mask_model = None
if self.enable_attention_masks: if self.enable_attention_masks:
attention_mask_model = attention_mask attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last": if self.layer == "last":
z = outputs[0].float() z = outputs[0].float()
@@ -482,7 +504,7 @@ class SDTokenizer:
return (embed, leftover) return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
''' '''
Takes a prompt and converts it to a list of (token, weight, word id) elements. Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors. Tokens can both be integer tokens and pre computed CLIP tensors.
@@ -596,7 +618,7 @@ class SD1Tokenizer:
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer) tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {} out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out return out

View File

@@ -26,7 +26,7 @@ class SDXLTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {} out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@@ -762,7 +762,7 @@ class LTXV(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.LTXV latent_format = latent_formats.LTXV
memory_usage_factor = 2.7 memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -826,6 +826,26 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
class HunyuanVideoI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 33,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoI2V(self, device=device)
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
return out
class CosmosT2V(supported_models_base.BASE): class CosmosT2V(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "cosmos", "image_model": "cosmos",
@@ -911,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
memory_usage_factor = 1.0 memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
@@ -939,6 +959,6 @@ class WAN21_I2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=True, device=device) out = model_base.WAN21(self, image_to_video=True, device=device)
return out return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@@ -93,7 +93,10 @@ class BertEmbeddings(torch.nn.Module):
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None, dtype=None): def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.word_embeddings(input_tokens, out_dtype=dtype) x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x) x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None: if token_type_ids is not None:
@@ -113,8 +116,8 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, dtype=dtype) x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None mask = None
if attention_mask is not None: if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

View File

@@ -18,7 +18,7 @@ class FluxTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {} out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)

View File

@@ -4,6 +4,7 @@ import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast from transformers import LlamaTokenizerFast
import torch import torch
import os import os
import numbers
def llama_detect(state_dict, prefix=""): def llama_detect(state_dict, prefix=""):
@@ -22,7 +23,7 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length) super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel): class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
@@ -38,15 +39,26 @@ class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {} out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text) if llama_template is None:
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids) llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
embed_count = 0
for r in llama_text_tokens:
for i in range(len(r)):
if r[i][0] == 128257:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
embed_count += 1
out["llama"] = llama_text_tokens
return out return out
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
@@ -80,20 +92,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0 template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]): extra_template_end = 0
if v[0] == 128007: # <|end_header_id|> extra_sizes = 0
template_end = i user_end = 9999999999999
images = []
tok_pairs = token_weight_pairs_llama[0]
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 128006:
if tok_pairs[i + 1][0] == 882:
if tok_pairs[i + 2][0] == 128007:
template_end = i + 2
user_end = -1
if elem == 128009 and user_end == -1:
user_end = i + 1
else:
if elem.get("original_type") == "image":
elem_size = elem.get("data").shape[0]
if template_end > 0:
if user_end == -1:
extra_template_end += elem_size - 1
else:
image_start = i + extra_sizes
image_end = i + elem_size + extra_sizes
images.append((image_start, image_end, elem.get("image_interleave", 1)))
extra_sizes += elem_size - 1
if llama_out.shape[1] > (template_end + 2): if llama_out.shape[1] > (template_end + 2):
if token_weight_pairs_llama[0][template_end + 1][0] == 271: if tok_pairs[template_end + 1][0] == 271:
template_end += 2 template_end += 2
llama_out = llama_out[:, template_end:] llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:] llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
if len(images) > 0:
out = []
for i in images:
out.append(llama_out[:, i[0]: i[1]: i[2]])
llama_output = torch.cat(out + [llama_output], dim=1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out return llama_output, l_pooled, llama_extra_out
def load_sd(self, sd): def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:

View File

@@ -37,7 +37,7 @@ class HyditTokenizer:
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {} out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids) out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids) out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)

View File

@@ -241,7 +241,10 @@ class Llama2_(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype) x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in: if self.normalize_in:

View File

@@ -43,7 +43,7 @@ class SD3Tokenizer:
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {} out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@@ -239,8 +239,11 @@ class T5(torch.nn.Module):
def set_input_embeddings(self, embeddings): def set_input_embeddings(self, embeddings):
self.shared = embeddings self.shared = embeddings
def forward(self, input_ids, *args, **kwargs): def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
if input_ids is None:
x = embeds
else:
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base x = torch.nan_to_num(x) #Fix for fp8 T5 base
return self.encoder(x, *args, **kwargs) return self.encoder(x, attention_mask=attention_mask, **kwargs)

View File

@@ -46,12 +46,18 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else: else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try: try:
sd = safetensors.torch.load_file(ckpt, device=device.type) with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e: except Exception as e:
if len(e.args) > 0: if len(e.args) > 0:
message = e.args[0] message = e.args[0]
@@ -77,7 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd sd = pl_sd
else: else:
sd = pl_sd sd = pl_sd
return sd return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None): def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None: if metadata is not None:

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import torchaudio import torchaudio
import torch import torch
import comfy.model_management import comfy.model_management
@@ -10,6 +12,7 @@ import random
import hashlib import hashlib
import node_helpers import node_helpers
from comfy.cli_args import args from comfy.cli_args import args
from comfy.comfy_types import FileLocator
class EmptyLatentAudio: class EmptyLatentAudio:
def __init__(self): def __init__(self):
@@ -164,7 +167,7 @@ class SaveAudio:
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list() results: list[FileLocator] = []
metadata = {} metadata = {}
if not args.disable_metadata: if not args.disable_metadata:

View File

@@ -454,7 +454,7 @@ class SamplerCustom:
return {"required": return {"required":
{"model": ("MODEL",), {"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}), "add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ), "positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
@@ -605,8 +605,14 @@ class DisableNoise:
class RandomNoise(DisableNoise): class RandomNoise(DisableNoise):
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required":{ return {
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "required": {
"noise_seed": ("INT", {
"default": 0,
"min": 0,
"max": 0xffffffffffffffff,
"control_after_generate": True,
}),
} }
} }

View File

@@ -1,4 +1,5 @@
import nodes import nodes
import node_helpers
import torch import torch
import comfy.model_management import comfy.model_management
@@ -38,7 +39,83 @@ class EmptyHunyuanLatentVideo:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, ) return ({"samples":latent}, )
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
class TextEncodeHunyuanVideo_ImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_vision_output, prompt, image_interleave):
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
return (clip.encode_from_tokens_scheduled(tokens), )
class HunyuanImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image)
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent
return (positive, out_latent)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
} }

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import nodes import nodes
import folder_paths import folder_paths
from comfy.cli_args import args from comfy.cli_args import args
@@ -9,6 +11,8 @@ import numpy as np
import json import json
import os import os
from comfy.comfy_types import FileLocator
MAX_RESOLUTION = nodes.MAX_RESOLUTION MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop: class ImageCrop:
@@ -99,7 +103,7 @@ class SaveAnimatedWEBP:
method = self.methods.get(method) method = self.methods.get(method)
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list() results: list[FileLocator] = []
pil_images = [] pil_images = []
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()

View File

@@ -19,8 +19,6 @@ class Load3D():
"image": ("LOAD_3D", {}), "image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -55,8 +53,6 @@ class Load3DAnimation():
"image": ("LOAD_3D_ANIMATION", {}), "image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING")
@@ -82,8 +78,6 @@ class Preview3D():
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True
@@ -102,8 +96,6 @@ class Preview3DAnimation():
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True

View File

@@ -1,9 +1,14 @@
import io
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
import comfy.model_management import comfy.model_management
import comfy.model_sampling import comfy.model_sampling
import comfy.utils
import math import math
import numpy as np
import av
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
class EmptyLTXVLatentVideo: class EmptyLTXVLatentVideo:
@classmethod @classmethod
@@ -33,7 +38,6 @@ class LTXVImgToVideo:
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@@ -42,16 +46,219 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models" CATEGORY = "conditioning/video_models"
FUNCTION = "generate" FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale): def generate(self, positive, negative, image, vae, width, height, length, batch_size):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t latent[:, :, :t.shape[2]] = t
return (positive, negative, {"samples": latent}, )
conditioning_latent_frames_mask = torch.ones(
(batch_size, 1, latent.shape[2], 1, 1),
dtype=torch.float32,
device=latent.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
return t[1][key]
return default
def get_noise_mask(latent):
noise_mask = latent.get("noise_mask", None)
latent_image = latent["samples"]
if noise_mask is None:
batch_size, _, latent_length, _, _ = latent_image.shape
noise_mask = torch.ones(
(batch_size, 1, latent_length, 1, 1),
dtype=torch.float32,
device=latent_image.device,
)
else:
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
keyframe_idxs = pixel_coords
else:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = latent_image.clone()
noise_mask = noise_mask.clone()
latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe(
positive,
negative,
frame_idx,
latent_image,
noise_mask,
t[:, :, :num_prefix_frames],
strength,
scale_factors,
)
latent_idx += num_prefix_frames
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image, noise_mask = self.replace_latent_frames(
latent_image,
noise_mask,
t,
latent_idx,
strength,
)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVCropGuides:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT",),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVConditioning: class LTXVConditioning:
@@ -174,6 +381,77 @@ class LTXVScheduler:
return (sigmas,) return (sigmas,)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file):
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def preprocess(image: torch.Tensor, crf=29):
if crf == 0:
return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
class LTXVPreprocess:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"img_compression": (
"INT",
{
"default": 35,
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
),
}
}
FUNCTION = "preprocess"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
if img_compression > 0:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
@@ -181,4 +459,7 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingLTXV": ModelSamplingLTXV, "ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning, "LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler, "LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
} }

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
from nodes import VAELoader
class VAELoaderDevice(VAELoader):
NodeId = "VAELoaderDevice"
NodeName = "Load VAE MultiGPU"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"vae_name": (cls.vae_list(), ),
"load_device": (comfy.multigpu.get_torch_device_list(), ),
}
}
FUNCTION = "load_vae_device"
CATEGORY = "advanced/multigpu/loaders"
def load_vae_device(self, vae_name, load_device: str):
device = comfy.multigpu.get_device_from_str(load_device)
return self.load_vae(vae_name, device)
class MultiGPUWorkUnitsNode:
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
NodeId = "MultiGPU_WorkUnits"
NodeName = "MultiGPU Work Units"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "init_multigpu"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
return (model,)
class MultiGPUOptionsNode:
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
"""
NodeId = "MultiGPU_Options"
NodeName = "MultiGPU Options"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("GPU_OPTIONS",)
FUNCTION = "create_gpu_options"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return (gpu_options,)
node_list = [
MultiGPUWorkUnitsNode,
MultiGPUOptionsNode,
VAELoaderDevice,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
import os import os
import av import av
import torch import torch
import folder_paths import folder_paths
import json import json
from fractions import Fraction from fractions import Fraction
from comfy.comfy_types import FileLocator
class SaveWEBM: class SaveWEBM:
@@ -62,7 +65,7 @@ class SaveWEBM:
container.mux(stream.encode()) container.mux(stream.encode())
container.close() container.close()
results = [{ results: list[FileLocator] = [{
"filename": file, "filename": file,
"subfolder": subfolder, "subfolder": subfolder,
"type": self.type "type": self.type

View File

@@ -4,6 +4,7 @@ import comfy.utils
import comfy.sd import comfy.sd
import folder_paths import folder_paths
import comfy_extras.nodes_model_merging import comfy_extras.nodes_model_merging
import node_helpers
class ImageOnlyCheckpointLoader: class ImageOnlyCheckpointLoader:
@@ -121,12 +122,38 @@ class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {} return {}
class ConditioningSetAreaPercentageVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.18" __version__ = "0.3.26"

View File

@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
continue continue
else: else:
try: try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
if type_input == "INT": if type_input == "INT":
val = int(val) val = int(val)
inputs[x] = val inputs[x] = val

View File

@@ -139,6 +139,7 @@ from server import BinaryEventTypes
import nodes import nodes
import comfy.model_management import comfy.model_management
import comfyui_version import comfyui_version
import app.logger
def cuda_malloc_warning(): def cuda_malloc_warning():
@@ -295,9 +296,12 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__": if __name__ == "__main__":
# Running directly, just start ComfyUI. # Running directly, just start ComfyUI.
logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui() event_loop, _, start_all_func = start_comfyui()
try: try:
event_loop.run_until_complete(start_all_func()) x = start_all_func()
app.logger.print_startup_warnings()
event_loop.run_until_complete(x)
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("\nStopped server") logging.info("\nStopped server")

View File

@@ -25,7 +25,7 @@ import comfy.sample
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.controlnet import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
import comfy.clip_vision import comfy.clip_vision
@@ -479,7 +479,7 @@ class SaveLatent:
file = f"{filename}_{counter:05}_.latent" file = f"{filename}_{counter:05}_.latent"
results = list() results: list[FileLocator] = []
results.append({ results.append({
"filename": file, "filename": file,
"subfolder": subfolder, "subfolder": subfolder,
@@ -489,7 +489,7 @@ class SaveLatent:
file = os.path.join(full_output_folder, file) file = os.path.join(full_output_folder, file)
output = {} output = {}
output["latent_tensor"] = samples["samples"] output["latent_tensor"] = samples["samples"].contiguous()
output["latent_format_version_0"] = torch.tensor([]) output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata) comfy.utils.save_torch_file(output, file, metadata=metadata)
@@ -763,13 +763,14 @@ class VAELoader:
CATEGORY = "loaders" CATEGORY = "loaders"
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name, device=None):
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd, device=device)
vae.throw_exception_if_invalid()
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:
@@ -1519,7 +1520,7 @@ class KSampler:
return { return {
"required": { "required": {
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
@@ -1547,7 +1548,7 @@ class KSamplerAdvanced:
return {"required": return {"required":
{"model": ("MODEL",), {"model": ("MODEL",),
"add_noise": (["enable", "disable"], ), "add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
@@ -1785,14 +1786,7 @@ class LoadImageOutput(LoadImage):
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration." DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True EXPERIMENTAL = True
FUNCTION = "load_image_output" FUNCTION = "load_image"
def load_image_output(self, image):
return self.load_image(f"{image} [output]")
@classmethod
def VALIDATE_INPUTS(s, image):
return True
class ImageScale: class ImageScale:
@@ -2265,6 +2259,7 @@ def init_builtin_extra_nodes():
"nodes_mahiro.py", "nodes_mahiro.py",
"nodes_lt.py", "nodes_lt.py",
"nodes_hooks.py", "nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py", "nodes_load_3d.py",
"nodes_cosmos.py", "nodes_cosmos.py",
"nodes_video.py", "nodes_video.py",

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.18" version = "0.3.26"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.10.17 comfyui-frontend-package==1.12.14
torch torch
torchsde torchsde
torchvision torchvision

View File

@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
def test_init_frontend_default(): def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string) frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH assert frontend_path == FrontendManager.default_frontend_path()
def test_init_frontend_invalid_version(): def test_init_frontend_invalid_version():
@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
with pytest.raises(HTTPError): with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string) FrontendManager.init_frontend_unsafe(version_string)
@pytest.fixture @pytest.fixture
def mock_os_functions(): def mock_os_functions():
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \ with (
patch('app.frontend_management.os.listdir') as mock_listdir, \ patch("app.frontend_management.os.makedirs") as mock_makedirs,
patch('app.frontend_management.os.rmdir') as mock_rmdir: patch("app.frontend_management.os.listdir") as mock_listdir,
patch("app.frontend_management.os.rmdir") as mock_rmdir,
):
mock_listdir.return_value = [] # Simulate empty directory mock_listdir.return_value = [] # Simulate empty directory
yield mock_makedirs, mock_listdir, mock_rmdir yield mock_makedirs, mock_listdir, mock_rmdir
@pytest.fixture @pytest.fixture
def mock_download(): def mock_download():
with patch('app.frontend_management.download_release_asset_zip') as mock: with patch("app.frontend_management.download_release_asset_zip") as mock:
mock.side_effect = Exception("Download failed") # Simulate download failure mock.side_effect = Exception("Download failed") # Simulate download failure
yield mock yield mock
def test_finally_block(mock_os_functions, mock_download, mock_provider): def test_finally_block(mock_os_functions, mock_download, mock_provider):
# Arrange # Arrange
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
version_string = 'test-owner/test-repo@1.0.0' version_string = "test-owner/test-repo@1.0.0"
# Act & Assert # Act & Assert
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
version_string = "invalid" version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError): with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string) FrontendManager.parse_version_string(version_string)
def test_init_frontend_default_with_mocks():
# Arrange
version_string = DEFAULT_VERSION_STRING
# Act
with (
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/mocked/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/mocked/path"
mock_check.assert_called_once()
def test_init_frontend_fallback_on_error():
# Arrange
version_string = "test-owner/test-repo@1.0.0"
# Act
with (
patch.object(
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
),
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/default/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()